Removed the feature where you can use comptime int or float ar rhs for operation
This commit is contained in:
parent
7844aacfce
commit
4595397e70
329
src/Tensor.zig
329
src/Tensor.zig
@ -13,39 +13,6 @@ inline fn isTensor(comptime Rhs: type) bool {
|
||||
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
|
||||
}
|
||||
|
||||
inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
|
||||
if (comptime isTensor(Rhs)) return Rhs;
|
||||
return TensorStatic(T, .{}, .{}, &.{1});
|
||||
}
|
||||
|
||||
/// Take the anyvalue coming from operation and if it is a Tensor, return it.
|
||||
/// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r).
|
||||
inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) {
|
||||
const is_ptr = @typeInfo(@TypeOf(r)) == .pointer;
|
||||
const Rhs = @TypeOf(if (is_ptr) r.* else r);
|
||||
if (comptime isTensor(Rhs)) return if (is_ptr) r.* else r;
|
||||
const scalar: T = switch (@typeInfo(Rhs)) {
|
||||
.comptime_int => switch (comptime @typeInfo(T)) {
|
||||
.float => @as(T, @floatFromInt(r)),
|
||||
else => @as(T, r),
|
||||
},
|
||||
.comptime_float => switch (comptime @typeInfo(T)) {
|
||||
.int => @as(T, @intFromFloat(r)),
|
||||
else => @as(T, r),
|
||||
},
|
||||
.int => switch (comptime @typeInfo(T)) {
|
||||
.float => @floatFromInt(r),
|
||||
else => @intCast(r),
|
||||
},
|
||||
.float => switch (comptime @typeInfo(T)) {
|
||||
.int => @intFromFloat(r),
|
||||
else => @floatCast(r),
|
||||
},
|
||||
else => @compileError("Unsupported RHS type: " ++ @typeName(Rhs)),
|
||||
};
|
||||
return TensorStatic(T, .{}, .{}, &.{1}){ .data = .{scalar} };
|
||||
}
|
||||
|
||||
/// SIMD implementation of a Tensor.
|
||||
/// Limited to tensor of ~2000 values.
|
||||
/// For more, see either TensorAlloc or TensorGPU
|
||||
@ -110,123 +77,86 @@ pub fn TensorStatic(
|
||||
return @as([*]T, @ptrCast(&self.data))[0..total];
|
||||
}
|
||||
|
||||
inline fn RhsT(comptime Rhs: type) type {
|
||||
return RhsTensorType(T, Rhs);
|
||||
}
|
||||
inline fn rhs(r: anytype) RhsT(@TypeOf(r)) {
|
||||
return toRhsTensor(T, r);
|
||||
}
|
||||
|
||||
inline fn broadcastToVec(comptime RhsType: type, r: RhsType) Vec {
|
||||
return if (comptime RhsType.total == 1 and total > 1)
|
||||
@splat(r.data[0])
|
||||
else
|
||||
r.data;
|
||||
}
|
||||
|
||||
/// Element-wise add. Dimensions must match; scales resolve to finer.
|
||||
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
||||
pub inline fn add(self: *const Self, r: anytype) TensorStatic(
|
||||
pub inline fn add(self: *const Self, rhs: anytype) TensorStatic(
|
||||
T,
|
||||
dims.argsOpt(),
|
||||
sh.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||
shape_,
|
||||
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
||||
shape,
|
||||
) {
|
||||
const rhs_t = rhs(r);
|
||||
const RhsType = @TypeOf(rhs_t);
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime !isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime !dims.eql(RhsType.dims))
|
||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too
|
||||
return .{ .data = if (comptime sh.isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data };
|
||||
|
||||
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||
const rr: Vec = blk: {
|
||||
const RhsNorm = TensorStatic(
|
||||
T,
|
||||
RhsType.dims.argsOpt(),
|
||||
sh.finerScales(Self, RhsType).argsOpt(),
|
||||
RhsType.shape,
|
||||
);
|
||||
const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm);
|
||||
break :blk broadcastToVec(RhsNorm, rn);
|
||||
};
|
||||
return .{ .data = if (comptime sh.isInt(T)) l +| rr else l + rr };
|
||||
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||
const l: Vec = self.to(TargetType).data;
|
||||
const r: Vec = rhs.to(TargetType).data;
|
||||
return .{ .data = if (comptime sh.isInt(T)) l +| r else l + r };
|
||||
}
|
||||
|
||||
/// Element-wise sub. Dimensions must match; scales resolve to finer.
|
||||
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
||||
pub inline fn sub(self: *const Self, r: anytype) TensorStatic(
|
||||
pub inline fn sub(self: *const Self, rhs: anytype) TensorStatic(
|
||||
T,
|
||||
dims.argsOpt(),
|
||||
sh.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||
shape_,
|
||||
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
||||
shape,
|
||||
) {
|
||||
const rhs_t = rhs(r);
|
||||
const RhsType = @TypeOf(rhs_t);
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime !isTensor(RhsType))
|
||||
@compileError("rhs can only be a Tensor ");
|
||||
if (comptime !dims.eql(RhsType.dims))
|
||||
@compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
||||
@compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too
|
||||
return .{ .data = if (comptime sh.isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data };
|
||||
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
|
||||
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||
const rr: Vec = blk: {
|
||||
const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm);
|
||||
break :blk broadcastToVec(RhsNorm, rn);
|
||||
};
|
||||
return .{ .data = if (comptime sh.isInt(T)) l -| rr else l - rr };
|
||||
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||
const l: Vec = self.to(TargetType).data;
|
||||
const r: Vec = rhs.to(TargetType).data;
|
||||
return .{ .data = if (comptime sh.isInt(T)) l -| r else l - r };
|
||||
}
|
||||
|
||||
/// Element-wise multiply. Dimension exponents summed.
|
||||
/// Shape {1} RHS is automatically broadcast across all elements.
|
||||
pub inline fn mul(self: *const Self, r: anytype) TensorStatic(
|
||||
pub inline fn mul(self: *const Self, rhs: anytype) TensorStatic(
|
||||
T,
|
||||
dims.add(RhsT(@TypeOf(r)).dims).argsOpt(),
|
||||
sh.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||
dims.add(@TypeOf(rhs).dims).argsOpt(),
|
||||
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
||||
shape_,
|
||||
) {
|
||||
const rhs_q = rhs(r);
|
||||
const RhsType = @TypeOf(rhs_q);
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
||||
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
|
||||
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
||||
return .{ .data = if (comptime sh.isInt(T)) l *| rr else l * rr };
|
||||
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||
const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||
const l: Vec = self.to(SelfNorm).data;
|
||||
const r: Vec = rhs.to(RhsNorm).data;
|
||||
return .{ .data = if (comptime sh.isInt(T)) l *| r else l * r };
|
||||
}
|
||||
|
||||
/// Element-wise divide. Dimension exponents subtracted.
|
||||
/// Shape {1} RHS is automatically broadcast across all elements.
|
||||
pub inline fn div(self: *const Self, r: anytype) TensorStatic(
|
||||
pub inline fn div(self: *const Self, rhs: anytype) TensorStatic(
|
||||
T,
|
||||
dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(),
|
||||
sh.finerScales(Self, RhsT(@TypeOf(r))).argsOpt(),
|
||||
dims.sub(@TypeOf(rhs).dims).argsOpt(),
|
||||
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
||||
shape_,
|
||||
) {
|
||||
const rhs_q = rhs(r);
|
||||
const RhsType = @TypeOf(rhs_q);
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
||||
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
|
||||
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||
const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||
const rr: Vec = broadcastToVec(RhsNorm, rr_base);
|
||||
if (comptime sh.isInt(T)) {
|
||||
return .{ .data = @divTrunc(l, rr) };
|
||||
} else {
|
||||
return .{ .data = l / rr };
|
||||
}
|
||||
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||
const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||
const l: Vec = self.to(SelfNorm).data;
|
||||
const r: Vec = rhs.to(RhsNorm).data;
|
||||
return .{ .data = if (comptime sh.isInt(T)) @divTrunc(l, r) else l / r };
|
||||
}
|
||||
|
||||
/// Absolute value of every element.
|
||||
@ -294,28 +224,31 @@ pub fn TensorStatic(
|
||||
|
||||
/// Convert to a compatible Tensor type.
|
||||
/// • Dimension mismatch → compile error.
|
||||
/// • Dest.shape must equal self.shape, or Dest.total == 1 (scalar pattern).
|
||||
/// • Dest.shape must equal self.shape, or total == 1 -> splat to Dest shape (scalar pattern).
|
||||
/// • Scale ratio is computed fully at comptime; only a SIMD multiply at runtime.
|
||||
pub inline fn to(
|
||||
self: *const Self,
|
||||
comptime Dest: type,
|
||||
) TensorStatic(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) {
|
||||
const ActualDest = TensorStatic(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_);
|
||||
|
||||
if (comptime Self == ActualDest) return self;
|
||||
) Dest {
|
||||
if (comptime Self == Dest) return self.*;
|
||||
|
||||
// Run validation checks FIRST before dealing with types
|
||||
if (comptime !dims.eql(ActualDest.dims))
|
||||
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
|
||||
if (comptime Dest.total != 1 and !sh.shapeEql(shape_, Dest.shape))
|
||||
if (comptime !dims.eql(Dest.dims))
|
||||
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ Dest.dims.str());
|
||||
if (comptime total != 1 and !sh.shapeEql(shape_, Dest.shape))
|
||||
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
|
||||
|
||||
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
|
||||
const DestT = ActualDest.ValueType;
|
||||
const DestVec = @Vector(total, DestT);
|
||||
const vec = if (comptime total == 1 and Dest.total != 1)
|
||||
TensorStatic(Dest.ValueType, dims.argsOpt(), scales.argsOpt(), Dest.shape){ .data = @splat(self.data[0]) }
|
||||
else
|
||||
self;
|
||||
|
||||
const ratio = comptime (scales.getFactor(dims) / Dest.scales.getFactor(Dest.dims));
|
||||
const DestT = Dest.ValueType;
|
||||
const DestVec = @Vector(Dest.total, DestT);
|
||||
|
||||
if (comptime ratio == 1.0 and T == DestT)
|
||||
return .{ .data = self.data };
|
||||
return .{ .data = vec.data };
|
||||
|
||||
// If ratio is 1, handle type conversion correctly based on BOTH source and dest types
|
||||
if (comptime ratio == 1.0) {
|
||||
@ -324,13 +257,13 @@ pub fn TensorStatic(
|
||||
|
||||
return .{
|
||||
.data = if (comptime T_info == .int and Dest_info == .int)
|
||||
@as(DestVec, @intCast(self.data))
|
||||
@as(DestVec, @intCast(vec.data))
|
||||
else if (comptime T_info == .float and Dest_info == .float)
|
||||
@as(DestVec, @floatCast(self.data))
|
||||
@as(DestVec, @floatCast(vec.data))
|
||||
else if (comptime T_info == .int and Dest_info == .float)
|
||||
@as(DestVec, @floatFromInt(self.data))
|
||||
@as(DestVec, @floatFromInt(vec.data))
|
||||
else if (comptime T_info == .float and Dest_info == .int)
|
||||
@as(DestVec, @intFromFloat(self.data)) // Or @intFromFloat(@round(self.data)) if you want rounding
|
||||
@as(DestVec, @intFromFloat(vec.data))
|
||||
else
|
||||
unreachable,
|
||||
};
|
||||
@ -338,22 +271,22 @@ pub fn TensorStatic(
|
||||
|
||||
if (comptime T == DestT) {
|
||||
if (comptime @typeInfo(T) == .float)
|
||||
return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) };
|
||||
return .{ .data = vec.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) };
|
||||
|
||||
if (comptime ratio >= 1.0) {
|
||||
const mult: T = comptime @intFromFloat(@round(ratio));
|
||||
return .{ .data = self.data *| @as(Vec, @splat(mult)) };
|
||||
return .{ .data = vec.data *| @as(Vec, @splat(mult)) };
|
||||
} else {
|
||||
const div_val: T = comptime @intFromFloat(@round(1.0 / ratio));
|
||||
const half: T = comptime @divTrunc(div_val, 2);
|
||||
|
||||
if (comptime @typeInfo(T).int.signedness == .unsigned) {
|
||||
return .{ .data = @divTrunc(self.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) };
|
||||
return .{ .data = @divTrunc(vec.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) };
|
||||
} else {
|
||||
// Vectorized branchless negative handling
|
||||
const is_pos = self.data >= @as(Vec, @splat(0));
|
||||
const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half)));
|
||||
return .{ .data = @divTrunc(self.data + offsets, @as(Vec, @splat(div_val))) };
|
||||
return .{ .data = @divTrunc(vec.data + offsets, @as(Vec, @splat(div_val))) };
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -361,8 +294,8 @@ pub fn TensorStatic(
|
||||
// Cross-type fully vectorized casting with scales
|
||||
const FVec = @Vector(total, f64);
|
||||
const float_vec: FVec = switch (comptime @typeInfo(T)) {
|
||||
.float => @floatCast(self.data),
|
||||
.int => @floatFromInt(self.data),
|
||||
.float => @floatCast(vec.data),
|
||||
.int => @floatFromInt(vec.data),
|
||||
else => unreachable,
|
||||
};
|
||||
|
||||
@ -382,66 +315,54 @@ pub fn TensorStatic(
|
||||
}
|
||||
|
||||
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
||||
inline fn resolveScalePair(self: *const Self, rhs_q: anytype) struct { l: Vec, r: Vec } {
|
||||
const RhsType = @TypeOf(rhs_q);
|
||||
inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: Vec, r: Vec } {
|
||||
const RhsType = @TypeOf(rhs);
|
||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
||||
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
|
||||
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||
const rr: Vec = blk: {
|
||||
const RhsNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||
const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm);
|
||||
break :blk broadcastToVec(RhsNorm, rn);
|
||||
};
|
||||
return .{ .l = l, .r = rr };
|
||||
return .{ .l = self.to(TargetType).data, .r = rhs.to(TargetType).data };
|
||||
}
|
||||
|
||||
pub inline fn eq(self: *const Self, r: anytype) CmpResult {
|
||||
const rhs_q = rhs(r);
|
||||
if (comptime !dims.eql(@TypeOf(rhs_q).dims))
|
||||
@compileError("Dimension mismatch in ne.");
|
||||
const p = resolveScalePair(self, rhs_q);
|
||||
pub inline fn eq(self: *const Self, rhs: anytype) CmpResult {
|
||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||
@compileError("Dimension mismatch in eq.");
|
||||
const p = resolveScalePair(self, rhs);
|
||||
return cmpResult(p.l == p.r);
|
||||
}
|
||||
|
||||
pub inline fn ne(self: *const Self, r: anytype) CmpResult {
|
||||
const rhs_q = rhs(r);
|
||||
if (comptime !dims.eql(@TypeOf(rhs_q).dims))
|
||||
pub inline fn ne(self: *const Self, rhs: anytype) CmpResult {
|
||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||
@compileError("Dimension mismatch in ne.");
|
||||
const p = resolveScalePair(self, rhs_q);
|
||||
const p = resolveScalePair(self, rhs);
|
||||
return cmpResult(p.l != p.r);
|
||||
}
|
||||
|
||||
pub inline fn gt(self: *const Self, r: anytype) CmpResult {
|
||||
const rhs_q = rhs(r);
|
||||
if (comptime !dims.eql(@TypeOf(rhs_q).dims))
|
||||
pub inline fn gt(self: *const Self, rhs: anytype) CmpResult {
|
||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||
@compileError("Dimension mismatch in gt.");
|
||||
const p = resolveScalePair(self, rhs_q);
|
||||
const p = resolveScalePair(self, rhs);
|
||||
return cmpResult(p.l > p.r);
|
||||
}
|
||||
|
||||
pub inline fn gte(self: *const Self, r: anytype) CmpResult {
|
||||
const rhs_q = rhs(r);
|
||||
if (comptime !dims.eql(@TypeOf(rhs_q).dims))
|
||||
pub inline fn gte(self: *const Self, rhs: anytype) CmpResult {
|
||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||
@compileError("Dimension mismatch in gte.");
|
||||
const p = resolveScalePair(self, rhs_q);
|
||||
const p = resolveScalePair(self, rhs);
|
||||
return cmpResult(p.l >= p.r);
|
||||
}
|
||||
|
||||
pub inline fn lt(self: *const Self, r: anytype) CmpResult {
|
||||
const rhs_q = rhs(r);
|
||||
if (comptime !dims.eql(@TypeOf(rhs_q).dims))
|
||||
pub inline fn lt(self: *const Self, rhs: anytype) CmpResult {
|
||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||
@compileError("Dimension mismatch in lt.");
|
||||
const p = resolveScalePair(self, rhs_q);
|
||||
const p = resolveScalePair(self, rhs);
|
||||
return cmpResult(p.l < p.r);
|
||||
}
|
||||
|
||||
pub inline fn lte(self: *const Self, r: anytype) CmpResult {
|
||||
const rhs_q = rhs(r);
|
||||
if (comptime !dims.eql(@TypeOf(rhs_q).dims))
|
||||
pub inline fn lte(self: *const Self, rhs: anytype) CmpResult {
|
||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||
@compileError("Dimension mismatch in lte.");
|
||||
const p = resolveScalePair(self, rhs_q);
|
||||
const p = resolveScalePair(self, rhs);
|
||||
return cmpResult(p.l <= p.r);
|
||||
}
|
||||
|
||||
@ -580,21 +501,20 @@ pub fn TensorStatic(
|
||||
|
||||
/// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
|
||||
/// Result dimensions are the sum of input dimensions.
|
||||
pub inline fn cross(self: *const Self, other: anytype) TensorStatic(
|
||||
pub inline fn cross(self: *const Self, rhs: anytype) TensorStatic(
|
||||
T,
|
||||
dims.add(RhsT(@TypeOf(other)).dims).argsOpt(),
|
||||
sh.finerScales(Self, RhsT(@TypeOf(other))).argsOpt(),
|
||||
dims.add(@TypeOf(rhs).dims).argsOpt(),
|
||||
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
||||
&.{3},
|
||||
) {
|
||||
const rhs_q = rhs(other);
|
||||
const RhsType = @TypeOf(rhs_q);
|
||||
const RhsType = @TypeOf(rhs);
|
||||
|
||||
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) {
|
||||
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
|
||||
}
|
||||
|
||||
// Bring both to the same scale (e.g., mm vs m)
|
||||
const p = self.resolveScalePair(rhs_q);
|
||||
const p = self.resolveScalePair(rhs);
|
||||
const l = p.l;
|
||||
const r = p.r;
|
||||
|
||||
@ -937,17 +857,11 @@ test "Scalar Pow" {
|
||||
try std.testing.expectEqual(64, d.pow(3).data[0]);
|
||||
}
|
||||
|
||||
test "Scalar mul comptime_int" {
|
||||
const Meter = TensorStatic(i128, .{ .L = 1 }, .{}, &.{1});
|
||||
const d = Meter.splat(7);
|
||||
try std.testing.expectEqual(21, d.mul(3).data[0]);
|
||||
}
|
||||
|
||||
test "Scalar add/sub bare number on dimensionless scalar" {
|
||||
const DimLess = TensorStatic(i128, .{}, .{}, &.{1});
|
||||
const a = DimLess.splat(10);
|
||||
try std.testing.expectEqual(15, a.add(5).data[0]);
|
||||
try std.testing.expectEqual(7, a.sub(3).data[0]);
|
||||
try std.testing.expectEqual(15, a.add(DimLess.splat(5)).data[0]);
|
||||
try std.testing.expectEqual(7, a.sub(DimLess.splat(3)).data[0]);
|
||||
}
|
||||
|
||||
test "Scalar Imperial length scales" {
|
||||
@ -967,13 +881,6 @@ test "Scalar Imperial mass scales" {
|
||||
try std.testing.expectApproxEqAbs(2.5, total.data[0], 1e-6);
|
||||
}
|
||||
|
||||
test "Scalar comparisons with comptime_int on dimensionless scalar" {
|
||||
const DimLess = TensorStatic(i128, .{}, .{}, &.{1});
|
||||
const x = DimLess.splat(42);
|
||||
try std.testing.expect(x.eq(42));
|
||||
try std.testing.expect(x.gt(10));
|
||||
}
|
||||
|
||||
// ─── Vector / Tensor tests ────────────────────────────────────────────────
|
||||
|
||||
test "Vector initiate" {
|
||||
@ -1180,58 +1087,14 @@ test "Vector Abs, Pow, Sqrt and Product" {
|
||||
try std.testing.expectEqual(1, @TypeOf(sqrted).dims.get(.L));
|
||||
}
|
||||
|
||||
test "Vector mul comptime_int broadcast" {
|
||||
const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3});
|
||||
const v = Meter3{ .data = .{ 1, 2, 3 } };
|
||||
const scaled = v.mul(10);
|
||||
try std.testing.expectEqual(10, scaled.data[0]);
|
||||
try std.testing.expectEqual(20, scaled.data[1]);
|
||||
try std.testing.expectEqual(30, scaled.data[2]);
|
||||
try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L));
|
||||
}
|
||||
|
||||
test "Vector mul comptime_float broadcast" {
|
||||
const MeterF3 = TensorStatic(f32, .{ .L = 1 }, .{}, &.{3});
|
||||
const v = MeterF3{ .data = .{ 1.0, 2.0, 4.0 } };
|
||||
const scaled = v.mul(0.5);
|
||||
try std.testing.expectApproxEqAbs(0.5, scaled.data[0], 1e-6);
|
||||
try std.testing.expectApproxEqAbs(1.0, scaled.data[1], 1e-6);
|
||||
try std.testing.expectApproxEqAbs(2.0, scaled.data[2], 1e-6);
|
||||
try std.testing.expectEqual(1, @TypeOf(scaled).dims.get(.L));
|
||||
}
|
||||
|
||||
test "Vector div comptime_int broadcast" {
|
||||
const Meter3 = TensorStatic(i32, .{ .L = 1 }, .{}, &.{3});
|
||||
const v = Meter3{ .data = .{ 10, 20, 30 } };
|
||||
const halved = v.div(2);
|
||||
try std.testing.expectEqual(5, halved.data[0]);
|
||||
try std.testing.expectEqual(10, halved.data[1]);
|
||||
try std.testing.expectEqual(15, halved.data[2]);
|
||||
try std.testing.expectEqual(1, @TypeOf(halved).dims.get(.L));
|
||||
}
|
||||
|
||||
test "Vector div comptime_float broadcast" {
|
||||
const MeterF3 = TensorStatic(f64, .{ .L = 1 }, .{}, &.{3});
|
||||
const v = MeterF3{ .data = .{ 9.0, 6.0, 3.0 } };
|
||||
const r = v.div(3.0);
|
||||
try std.testing.expectApproxEqAbs(3.0, r.data[0], 1e-9);
|
||||
try std.testing.expectApproxEqAbs(2.0, r.data[1], 1e-9);
|
||||
try std.testing.expectApproxEqAbs(1.0, r.data[2], 1e-9);
|
||||
}
|
||||
|
||||
test "Vector eq broadcast on dimensionless" {
|
||||
const DimLess3 = TensorStatic(i32, .{}, .{}, &.{3});
|
||||
const v = DimLess3{ .data = .{ 1, 2, 3 } };
|
||||
|
||||
const eq_res = v.eq(2);
|
||||
const eq_res = v.eq(DimLess3.splat(2));
|
||||
try std.testing.expectEqual(false, eq_res[0]);
|
||||
try std.testing.expectEqual(true, eq_res[1]);
|
||||
try std.testing.expectEqual(false, eq_res[2]);
|
||||
|
||||
const gt_res = v.gt(1);
|
||||
try std.testing.expectEqual(false, gt_res[0]);
|
||||
try std.testing.expectEqual(true, gt_res[1]);
|
||||
try std.testing.expectEqual(true, gt_res[2]);
|
||||
}
|
||||
|
||||
test "Tensor idx helper and matrix access" {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user