Added cross to Tensor + fix benchmark
parent 16d25e7e7e
commit cd954b379b
@@ -17,9 +17,8 @@ pub fn shapeTotal(comptime shape: []const usize) usize {
 /// Check if two shapes are strictly identical.
 pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool {
     if (a.len != b.len) return false;
-    for (a, 0..) |v, i| {
+    for (a, 0..) |v, i|
         if (v != b[i]) return false;
-    }
     return true;
 }
 
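A minimal usage sketch of the tightened shapeEql (not part of the commit; a standalone test with std imported locally so the snippet is self-contained):

test "shapeEql compares shapes element-wise" {
    const std = @import("std");
    try std.testing.expect(shapeEql(&.{ 2, 3 }, &.{ 2, 3 }));
    try std.testing.expect(!shapeEql(&.{ 2, 3 }, &.{ 3, 2 }));
    try std.testing.expect(!shapeEql(&.{3}, &.{ 3, 1 }));
}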
@@ -432,6 +431,24 @@ pub fn Tensor(
 ) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) {
     const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_);
 
+    if (comptime Self == ActualDest) return self;
+
+    const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
+    const DestT = ActualDest.ValueType;
+    const DestVec = @Vector(total, DestT);
+
+    // If ratio is 1, just handle type conversion
+    if (comptime ratio == 1.0) {
+        if (comptime T == DestT) return .{ .data = self.data };
+        return .{
+            .data = switch (comptime @typeInfo(DestT)) {
+                .float => @floatFromInt(self.data), // or @floatCast
+                .int => @intFromFloat(self.data), // or @intCast
+                else => unreachable,
+            },
+        };
+    }
+
     if (comptime !dims.eql(ActualDest.dims))
         @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
     if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
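For intuition, ratio is the source scale factor divided by the destination scale factor for the same dimensions, and the stored data is multiplied by it. A standalone sketch of that arithmetic (the 1e3 value for a kilo prefix on an L^1 dimension is an assumption for illustration):

const std = @import("std");

test "scale ratio arithmetic (illustration only)" {
    // km -> m for dims L^1: assumed source factor 1e3, destination factor 1e0.
    const ratio: f64 = 1.0e3 / 1.0e0;
    // With ratio != 1, the stored value is scaled: 1.5 km becomes 1500 m.
    try std.testing.expectEqual(@as(f64, 1500.0), 1.5 * ratio);
}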
@@ -439,10 +456,6 @@ pub fn Tensor(
 
-    if (comptime Self == ActualDest) return self;
 
-    const DestT = ActualDest.ValueType;
-    const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
-    const DestVec = @Vector(total, DestT);
 
     if (comptime T == DestT) {
         if (comptime @typeInfo(T) == .float)
             return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) };
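A hedged caller-side sketch of the conversion after this refactor (hypothetical: the exact argument list of to and the splat constructor are taken from surrounding code, not from this hunk):

// Hypothetical usage, mirroring the Tensor(...) declarations used in the benchmark below.
const KM = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});
const M = Tensor(f64, .{ .L = 1 }, .{}, &.{1});

const d_km = KM.splat(1.5);
// Signature assumed: destination type (and shape) passed as comptime arguments.
const d_m = d_km.to(M, &.{1}); // ratio = 1000.0 here, so the stored value becomes 1500.0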
@@ -485,7 +498,7 @@ pub fn Tensor(
 const CmpResult = if (total == 1) bool else [total]bool;
 
 inline fn cmpResult(v: @Vector(total, bool)) CmpResult {
-    return if (comptime total == 1) v[0] else @as([total]bool, v);
+    return if (comptime total == 1) @reduce(.And, v) else @as([total]bool, v);
 }
 
 /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
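The cmpResult change only affects how a single-element comparison vector is collapsed; a standalone sketch showing the two forms agree for length 1 (plain Zig, independent of the Tensor type):

const std = @import("std");

test "reducing a one-element bool vector" {
    const v: @Vector(1, bool) = .{true};
    // For length 1 these are identical; @reduce(.And, v) also generalises to longer vectors.
    try std.testing.expectEqual(v[0], @reduce(.And, v));
}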
@@ -634,6 +647,40 @@ pub fn Tensor(
     return result;
 }
 
+/// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
+/// Result dimensions are the sum of input dimensions.
+pub inline fn cross(self: Self, other: anytype) Tensor(
+    T,
+    dims.add(RhsT(@TypeOf(other)).dims).argsOpt(),
+    finerScales(Self, RhsT(@TypeOf(other))).argsOpt(),
+    &.{3},
+) {
+    const rhs_q = rhs(other);
+    const RhsType = @TypeOf(rhs_q);
+
+    if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) {
+        @compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
+    }
+
+    // Bring both to the same scale (e.g., mm vs m)
+    const p = self.resolveScalePair(rhs_q);
+    const l = p.l;
+    const r = p.r;
+
+    var res: [3]T = undefined;
+    if (comptime isInt(T)) {
+        res[0] = (l[1] *| r[2]) -| (l[2] *| r[1]);
+        res[1] = (l[2] *| r[0]) -| (l[0] *| r[2]);
+        res[2] = (l[0] *| r[1]) -| (l[1] *| r[0]);
+    } else {
+        res[0] = (l[1] * r[2]) - (l[2] * r[1]);
+        res[1] = (l[2] * r[0]) - (l[0] * r[2]);
+        res[2] = (l[0] * r[1]) - (l[1] * r[0]);
+    }
+
+    return .{ .data = res };
+}
+
 /// Sum of squared elements. Cheaper than length(); use for ordering.
 pub inline fn lengthSqr(self: Self) T {
     return @reduce(.Add, self.data * self.data);
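A hedged usage sketch of the new cross (illustrative only; the Tensor declarations mirror the benchmark below, and the scale/dims behaviour follows the doc comment and resolveScalePair call above):

// Hypothetical usage: unit X vector in km crossed with unit Y vector in m.
const Vkm = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{3});
const Vm = Tensor(f64, .{ .L = 1 }, .{}, &.{3});

const x_km = Vkm{ .data = .{ 1.0, 0.0, 0.0 } };
const y_m = Vm{ .data = .{ 0.0, 1.0, 0.0 } };
// Both sides are resolved to the finer (metre) scale first, then the usual component
// formula runs; dims add, so the result is an L^2 quantity along Z: {0, 0, 1000}.
const z = x_km.cross(y_m);
_ = z;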
@@ -1,7 +1,6 @@
 const std = @import("std");
 const Io = std.Io;
-const Scalar = @import("Quantity.zig").Scalar;
 const Vector = @import("Quantity.zig").Vector;
 const Tensor = @import("Tensor.zig").Tensor;
 
 var io: Io = undefined;
 pub fn main(init: std.process.Init) !void {
@@ -11,23 +10,23 @@ pub fn main(init: std.process.Init) !void {
 
     io = init.io;
 
-    // try vectorSIMDvsNative(f64, &stdout_writer.interface);
-    // try stdout_writer.flush();
-    // try vectorSIMDvsNative(f32, &stdout_writer.interface);
-    // try stdout_writer.flush();
-    // try vectorSIMDvsNative(i32, &stdout_writer.interface);
-    // try stdout_writer.flush();
-    // try vectorSIMDvsNative(i64, &stdout_writer.interface);
-    // try stdout_writer.flush();
-    // try vectorSIMDvsNative(i128, &stdout_writer.interface);
-    // try stdout_writer.flush();
+    try vectorSIMDvsNative(f64, &stdout_writer.interface);
+    try stdout_writer.flush();
+    try vectorSIMDvsNative(f32, &stdout_writer.interface);
+    try stdout_writer.flush();
+    try vectorSIMDvsNative(i32, &stdout_writer.interface);
+    try stdout_writer.flush();
+    try vectorSIMDvsNative(i64, &stdout_writer.interface);
+    try stdout_writer.flush();
+    try vectorSIMDvsNative(i128, &stdout_writer.interface);
+    try stdout_writer.flush();
 
     try bench_Scalar(&stdout_writer.interface);
     try stdout_writer.flush();
-    try bench_vsNative(&stdout_writer.interface);
-    try stdout_writer.flush();
-    try bench_crossTypeVsNative(&stdout_writer.interface);
-    try stdout_writer.flush();
+    // try bench_vsNative(&stdout_writer.interface);
+    // try stdout_writer.flush();
+    // try bench_crossTypeVsNative(&stdout_writer.interface);
+    // try stdout_writer.flush();
     try bench_Vector(&stdout_writer.interface);
     try stdout_writer.flush();
 }
@@ -97,9 +96,9 @@ fn bench_Scalar(writer: *std.Io.Writer) !void {
 
     comptime var tidx: usize = 0;
     inline for (Types, TNames) |T, tname| {
-        const M = Scalar(T, .{ .L = 1 }, .{});
-        const KM = Scalar(T, .{ .L = 1 }, .{ .L = .k });
-        const S = Scalar(T, .{ .T = 1 }, .{});
+        const M = Tensor(T, .{ .L = 1 }, .{}, &.{1});
+        const KM = Tensor(T, .{ .L = 1 }, .{ .L = .k }, &.{1});
+        const S = Tensor(T, .{ .T = 1 }, .{}, &.{1});
 
         inline for (Ops, 0..) |op_name, oidx| {
             var samples: [SAMPLES]f64 = undefined;
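The benchmark edits all follow one pattern: the old Scalar(T, dims, scales) wrapper becomes a rank-1, length-1 Tensor. A small sketch of that idea (illustrative; splat and the import path are taken from the benchmark file itself):

const Tensor = @import("Tensor.zig").Tensor;

// A shape {1} tensor plays the role of the old Scalar wrapper.
const Metres = Tensor(f64, .{ .L = 1 }, .{}, &.{1});
const Seconds = Tensor(f64, .{ .T = 1 }, .{}, &.{1});

test "scalar-like tensors (illustration)" {
    const d = Metres.splat(10.0);
    const t = Seconds.splat(2.0);
    _ = d;
    _ = t;
}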
@@ -199,8 +198,8 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
         var native_total_ns: f64 = 0;
         var quantity_total_ns: f64 = 0;
 
-        const M = Scalar(T, .{ .L = 1 }, .{});
-        const S = Scalar(T, .{ .T = 1 }, .{});
+        const M = Tensor(T, .{ .L = 1 }, .{}, &.{1});
+        const S = Tensor(T, .{ .T = 1 }, .{}, &.{1});
 
         std.mem.doNotOptimizeAway({
             for (0..SAMPLES) |_| {
@@ -321,9 +320,9 @@ fn bench_crossTypeVsNative(writer: *std.Io.Writer) !void {
         var native_total_ns: f64 = 0;
         var quantity_total_ns: f64 = 0;
 
-        const M1 = Scalar(T1, .{ .L = 1 }, .{});
-        const M2 = Scalar(T2, .{ .L = 1 }, .{});
-        const S2 = Scalar(T2, .{ .T = 1 }, .{});
+        const M1 = Tensor(T1, .{ .L = 1 }, .{}, &.{1});
+        const M2 = Tensor(T2, .{ .L = 1 }, .{}, &.{1});
+        const S2 = Tensor(T2, .{ .T = 1 }, .{}, &.{1});
 
         std.mem.doNotOptimizeAway({
             for (0..SAMPLES) |_| {
@@ -429,9 +428,8 @@ fn bench_Vector(writer: *std.Io.Writer) !void {
         try writer.print("│ {s:<16} │ {s:<4} │", .{ op_name, tname });
 
         inline for (Lengths) |len| {
-            const Q_base = Scalar(T, .{ .L = 1 }, .{});
-            const Q_time = Scalar(T, .{ .T = 1 }, .{});
-            const V = Vector(len, Q_base);
+            const Q_time = Tensor(T, .{ .T = 1 }, .{}, &.{1});
+            const V = Tensor(T, .{ .L = 1 }, .{}, &.{len});
 
             // cross product is only defined for len == 3
             const is_cross = comptime std.mem.eql(u8, op_name, "cross");
@@ -455,10 +453,10 @@ fn bench_Vector(writer: *std.Io.Writer) !void {
                     _ = v1.div(V.splat(getVal(T, i +% 2, 63)));
                 } else if (comptime std.mem.eql(u8, op_name, "mulScalar")) {
                     const s_val = Q_time.splat(getVal(T, i +% 2, 63));
-                    _ = v1.mulScalar(s_val);
+                    _ = v1.mul(s_val);
                 } else if (comptime std.mem.eql(u8, op_name, "dot")) {
                     const v2 = V.splat(getVal(T, i +% 5, 63));
-                    _ = v1.dot(v2);
+                    _ = v1.contract(v2, 0, 0);
                 } else if (comptime std.mem.eql(u8, op_name, "cross")) {
                     // len == 3 guaranteed by the guard above
                     const v2 = V.splat(getVal(T, i +% 5, 63));
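The last hunk tracks two renames in the Tensor API: scalar multiplication goes through the general mul, and dot is apparently expressed as a contraction over index 0 of each rank-1 operand. A standalone sketch of what that contraction computes (plain Zig, no Tensor types involved):

const std = @import("std");

test "contracting index 0 of two rank-1 arrays is the dot product" {
    const a = [3]f64{ 1.0, 2.0, 3.0 };
    const b = [3]f64{ 4.0, 5.0, 6.0 };
    var sum: f64 = 0.0;
    for (a, b) |x, y| sum += x * y; // sum over the shared (contracted) index
    try std.testing.expectEqual(@as(f64, 32.0), sum);
}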