Added cross to Tensor + fix benchmark

This commit is contained in:
AdrienBouvais 2026-04-27 15:13:15 +02:00
parent 16d25e7e7e
commit cd954b379b
2 changed files with 81 additions and 36 deletions

View File

@ -17,9 +17,8 @@ pub fn shapeTotal(comptime shape: []const usize) usize {
/// Check if two shapes are strictly identical. /// Check if two shapes are strictly identical.
pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool { pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool {
if (a.len != b.len) return false; if (a.len != b.len) return false;
for (a, 0..) |v, i| { for (a, 0..) |v, i|
if (v != b[i]) return false; if (v != b[i]) return false;
}
return true; return true;
} }
@ -432,6 +431,24 @@ pub fn Tensor(
) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) { ) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) {
const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_); const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_);
if (comptime Self == ActualDest) return self;
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
const DestT = ActualDest.ValueType;
const DestVec = @Vector(total, DestT);
// If ratio is 1, just handle type conversion
if (comptime ratio == 1.0) {
if (comptime T == DestT) return .{ .data = self.data };
return .{
.data = switch (comptime @typeInfo(DestT)) {
.float => @floatFromInt(self.data), // or @floatCast
.int => @intFromFloat(self.data), // or @intCast
else => unreachable,
},
};
}
if (comptime !dims.eql(ActualDest.dims)) if (comptime !dims.eql(ActualDest.dims))
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str()); @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape)) if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
@ -439,10 +456,6 @@ pub fn Tensor(
if (comptime Self == ActualDest) return self; if (comptime Self == ActualDest) return self;
const DestT = ActualDest.ValueType;
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
const DestVec = @Vector(total, DestT);
if (comptime T == DestT) { if (comptime T == DestT) {
if (comptime @typeInfo(T) == .float) if (comptime @typeInfo(T) == .float)
return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) }; return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) };
@ -485,7 +498,7 @@ pub fn Tensor(
const CmpResult = if (total == 1) bool else [total]bool; const CmpResult = if (total == 1) bool else [total]bool;
inline fn cmpResult(v: @Vector(total, bool)) CmpResult { inline fn cmpResult(v: @Vector(total, bool)) CmpResult {
return if (comptime total == 1) v[0] else @as([total]bool, v); return if (comptime total == 1) @reduce(.And, v) else @as([total]bool, v);
} }
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
@ -634,6 +647,40 @@ pub fn Tensor(
return result; return result;
} }
/// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
/// Result dimensions are the sum of input dimensions.
/// Result scales are the finer of the two operands' scales (via finerScales);
/// element type is the LHS `T`. NOTE(review): assumes `rhs()` coerces `other`
/// into a compatible Tensor — confirm against RhsT/rhs definitions elsewhere
/// in this file.
pub inline fn cross(self: Self, other: anytype) Tensor(
    T,
    dims.add(RhsT(@TypeOf(other)).dims).argsOpt(),
    finerScales(Self, RhsT(@TypeOf(other))).argsOpt(),
    &.{3},
) {
    const rhs_q = rhs(other);
    const RhsType = @TypeOf(rhs_q);
    // Comptime guard: both operands must be rank-1 tensors of exactly 3 elements.
    if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) {
        @compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
    }
    // Bring both to the same scale (e.g., mm vs m)
    const p = self.resolveScalePair(rhs_q);
    const l = p.l;
    const r = p.r;
    var res: [3]T = undefined;
    if (comptime isInt(T)) {
        // Integer path: saturating multiply (*|) and subtract (-|) so
        // out-of-range products clamp instead of trapping on overflow.
        res[0] = (l[1] *| r[2]) -| (l[2] *| r[1]);
        res[1] = (l[2] *| r[0]) -| (l[0] *| r[2]);
        res[2] = (l[0] *| r[1]) -| (l[1] *| r[0]);
    } else {
        // Float path: plain IEEE arithmetic, standard a×b component formula.
        res[0] = (l[1] * r[2]) - (l[2] * r[1]);
        res[1] = (l[2] * r[0]) - (l[0] * r[2]);
        res[2] = (l[0] * r[1]) - (l[1] * r[0]);
    }
    // [3]T array literal coerces into the result tensor's vector storage.
    return .{ .data = res };
}
/// Sum of squared elements. Cheaper than length(); use for ordering. /// Sum of squared elements. Cheaper than length(); use for ordering.
pub inline fn lengthSqr(self: Self) T { pub inline fn lengthSqr(self: Self) T {
return @reduce(.Add, self.data * self.data); return @reduce(.Add, self.data * self.data);

View File

@ -1,7 +1,6 @@
const std = @import("std"); const std = @import("std");
const Io = std.Io; const Io = std.Io;
const Scalar = @import("Quantity.zig").Scalar; const Tensor = @import("Tensor.zig").Tensor;
const Vector = @import("Quantity.zig").Vector;
var io: Io = undefined; var io: Io = undefined;
pub fn main(init: std.process.Init) !void { pub fn main(init: std.process.Init) !void {
@ -11,23 +10,23 @@ pub fn main(init: std.process.Init) !void {
io = init.io; io = init.io;
// try vectorSIMDvsNative(f64, &stdout_writer.interface); try vectorSIMDvsNative(f64, &stdout_writer.interface);
// try stdout_writer.flush(); try stdout_writer.flush();
// try vectorSIMDvsNative(f32, &stdout_writer.interface); try vectorSIMDvsNative(f32, &stdout_writer.interface);
// try stdout_writer.flush(); try stdout_writer.flush();
// try vectorSIMDvsNative(i32, &stdout_writer.interface); try vectorSIMDvsNative(i32, &stdout_writer.interface);
// try stdout_writer.flush(); try stdout_writer.flush();
// try vectorSIMDvsNative(i64, &stdout_writer.interface); try vectorSIMDvsNative(i64, &stdout_writer.interface);
// try stdout_writer.flush(); try stdout_writer.flush();
// try vectorSIMDvsNative(i128, &stdout_writer.interface); try vectorSIMDvsNative(i128, &stdout_writer.interface);
// try stdout_writer.flush(); try stdout_writer.flush();
try bench_Scalar(&stdout_writer.interface); try bench_Scalar(&stdout_writer.interface);
try stdout_writer.flush(); try stdout_writer.flush();
try bench_vsNative(&stdout_writer.interface); // try bench_vsNative(&stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
try bench_crossTypeVsNative(&stdout_writer.interface); // try bench_crossTypeVsNative(&stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
try bench_Vector(&stdout_writer.interface); try bench_Vector(&stdout_writer.interface);
try stdout_writer.flush(); try stdout_writer.flush();
} }
@ -97,9 +96,9 @@ fn bench_Scalar(writer: *std.Io.Writer) !void {
comptime var tidx: usize = 0; comptime var tidx: usize = 0;
inline for (Types, TNames) |T, tname| { inline for (Types, TNames) |T, tname| {
const M = Scalar(T, .{ .L = 1 }, .{}); const M = Tensor(T, .{ .L = 1 }, .{}, &.{1});
const KM = Scalar(T, .{ .L = 1 }, .{ .L = .k }); const KM = Tensor(T, .{ .L = 1 }, .{ .L = .k }, &.{1});
const S = Scalar(T, .{ .T = 1 }, .{}); const S = Tensor(T, .{ .T = 1 }, .{}, &.{1});
inline for (Ops, 0..) |op_name, oidx| { inline for (Ops, 0..) |op_name, oidx| {
var samples: [SAMPLES]f64 = undefined; var samples: [SAMPLES]f64 = undefined;
@ -199,8 +198,8 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
var native_total_ns: f64 = 0; var native_total_ns: f64 = 0;
var quantity_total_ns: f64 = 0; var quantity_total_ns: f64 = 0;
const M = Scalar(T, .{ .L = 1 }, .{}); const M = Tensor(T, .{ .L = 1 }, .{}, &.{1});
const S = Scalar(T, .{ .T = 1 }, .{}); const S = Tensor(T, .{ .T = 1 }, .{}, &.{1});
std.mem.doNotOptimizeAway({ std.mem.doNotOptimizeAway({
for (0..SAMPLES) |_| { for (0..SAMPLES) |_| {
@ -321,9 +320,9 @@ fn bench_crossTypeVsNative(writer: *std.Io.Writer) !void {
var native_total_ns: f64 = 0; var native_total_ns: f64 = 0;
var quantity_total_ns: f64 = 0; var quantity_total_ns: f64 = 0;
const M1 = Scalar(T1, .{ .L = 1 }, .{}); const M1 = Tensor(T1, .{ .L = 1 }, .{}, &.{1});
const M2 = Scalar(T2, .{ .L = 1 }, .{}); const M2 = Tensor(T2, .{ .L = 1 }, .{}, &.{1});
const S2 = Scalar(T2, .{ .T = 1 }, .{}); const S2 = Tensor(T2, .{ .T = 1 }, .{}, &.{1});
std.mem.doNotOptimizeAway({ std.mem.doNotOptimizeAway({
for (0..SAMPLES) |_| { for (0..SAMPLES) |_| {
@ -429,9 +428,8 @@ fn bench_Vector(writer: *std.Io.Writer) !void {
try writer.print("│ {s:<16} │ {s:<4} │", .{ op_name, tname }); try writer.print("│ {s:<16} │ {s:<4} │", .{ op_name, tname });
inline for (Lengths) |len| { inline for (Lengths) |len| {
const Q_base = Scalar(T, .{ .L = 1 }, .{}); const Q_time = Tensor(T, .{ .T = 1 }, .{}, &.{1});
const Q_time = Scalar(T, .{ .T = 1 }, .{}); const V = Tensor(T, .{ .L = 1 }, .{}, &.{len});
const V = Vector(len, Q_base);
// cross product is only defined for len == 3 // cross product is only defined for len == 3
const is_cross = comptime std.mem.eql(u8, op_name, "cross"); const is_cross = comptime std.mem.eql(u8, op_name, "cross");
@ -455,10 +453,10 @@ fn bench_Vector(writer: *std.Io.Writer) !void {
_ = v1.div(V.splat(getVal(T, i +% 2, 63))); _ = v1.div(V.splat(getVal(T, i +% 2, 63)));
} else if (comptime std.mem.eql(u8, op_name, "mulScalar")) { } else if (comptime std.mem.eql(u8, op_name, "mulScalar")) {
const s_val = Q_time.splat(getVal(T, i +% 2, 63)); const s_val = Q_time.splat(getVal(T, i +% 2, 63));
_ = v1.mulScalar(s_val); _ = v1.mul(s_val);
} else if (comptime std.mem.eql(u8, op_name, "dot")) { } else if (comptime std.mem.eql(u8, op_name, "dot")) {
const v2 = V.splat(getVal(T, i +% 5, 63)); const v2 = V.splat(getVal(T, i +% 5, 63));
_ = v1.dot(v2); _ = v1.contract(v2, 0, 0);
} else if (comptime std.mem.eql(u8, op_name, "cross")) { } else if (comptime std.mem.eql(u8, op_name, "cross")) {
// len == 3 guaranteed by the guard above // len == 3 guaranteed by the guard above
const v2 = V.splat(getVal(T, i +% 5, 63)); const v2 = V.splat(getVal(T, i +% 5, 63));