From cd954b379b7eedd623d4928b929626427ed121f6 Mon Sep 17 00:00:00 2001 From: AdrienBouvais Date: Mon, 27 Apr 2026 15:13:15 +0200 Subject: [PATCH] Added cross to Tensor + fix benchmark --- src/Tensor.zig | 61 +++++++++++++++++++++++++++++++++++++++++------ src/benchmark.zig | 56 +++++++++++++++++++++---------------------- 2 files changed, 81 insertions(+), 36 deletions(-) diff --git a/src/Tensor.zig b/src/Tensor.zig index 6cc93ab..ff97dc0 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -17,9 +17,8 @@ pub fn shapeTotal(comptime shape: []const usize) usize { /// Check if two shapes are strictly identical. pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool { if (a.len != b.len) return false; - for (a, 0..) |v, i| { + for (a, 0..) |v, i| if (v != b[i]) return false; - } return true; } @@ -432,6 +431,24 @@ pub fn Tensor( ) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) { const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_); + if (comptime Self == ActualDest) return self; + + const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); + const DestT = ActualDest.ValueType; + const DestVec = @Vector(total, DestT); + + // If ratio is 1, just handle type conversion + if (comptime ratio == 1.0 and dims.eql(ActualDest.dims) and (Dest.total == 1 or shapeEql(shape_, Dest.shape))) { + if (comptime T == DestT) return .{ .data = self.data }; + return .{ + .data = switch (comptime @typeInfo(DestT)) { + .float => @floatFromInt(self.data), // or @floatCast + .int => @intFromFloat(self.data), // or @intCast + else => unreachable, + }, + }; + } + if (comptime !dims.eql(ActualDest.dims)) @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str()); if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape)) @@ -439,10 +456,6 @@ if (comptime Self == ActualDest) return self; - const DestT = ActualDest.ValueType; - const ratio = comptime (scales.getFactor(dims) / 
ActualDest.scales.getFactor(ActualDest.dims)); - const DestVec = @Vector(total, DestT); - if (comptime T == DestT) { if (comptime @typeInfo(T) == .float) return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) }; @@ -485,7 +498,7 @@ pub fn Tensor( const CmpResult = if (total == 1) bool else [total]bool; inline fn cmpResult(v: @Vector(total, bool)) CmpResult { - return if (comptime total == 1) v[0] else @as([total]bool, v); + return if (comptime total == 1) @reduce(.And, v) else @as([total]bool, v); } /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. @@ -634,6 +647,40 @@ pub fn Tensor( return result; } + /// 3D Cross Product. Only defined for Rank-1 tensors of length 3. + /// Result dimensions are the sum of input dimensions. + pub inline fn cross(self: Self, other: anytype) Tensor( + T, + dims.add(RhsT(@TypeOf(other)).dims).argsOpt(), + finerScales(Self, RhsT(@TypeOf(other))).argsOpt(), + &.{3}, + ) { + const rhs_q = rhs(other); + const RhsType = @TypeOf(rhs_q); + + if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) { + @compileError("cross product is only defined for 3D vectors (rank-1, length 3)"); + } + + // Bring both to the same scale (e.g., mm vs m) + const p = self.resolveScalePair(rhs_q); + const l = p.l; + const r = p.r; + + var res: [3]T = undefined; + if (comptime isInt(T)) { + res[0] = (l[1] *| r[2]) -| (l[2] *| r[1]); + res[1] = (l[2] *| r[0]) -| (l[0] *| r[2]); + res[2] = (l[0] *| r[1]) -| (l[1] *| r[0]); + } else { + res[0] = (l[1] * r[2]) - (l[2] * r[1]); + res[1] = (l[2] * r[0]) - (l[0] * r[2]); + res[2] = (l[0] * r[1]) - (l[1] * r[0]); + } + + return .{ .data = res }; + } + /// Sum of squared elements. Cheaper than length(); use for ordering. 
pub inline fn lengthSqr(self: Self) T { return @reduce(.Add, self.data * self.data); diff --git a/src/benchmark.zig b/src/benchmark.zig index 7da3e6f..c50b3a9 100644 --- a/src/benchmark.zig +++ b/src/benchmark.zig @@ -1,7 +1,6 @@ const std = @import("std"); const Io = std.Io; -const Scalar = @import("Quantity.zig").Scalar; -const Vector = @import("Quantity.zig").Vector; +const Tensor = @import("Tensor.zig").Tensor; var io: Io = undefined; pub fn main(init: std.process.Init) !void { @@ -11,23 +10,23 @@ pub fn main(init: std.process.Init) !void { io = init.io; - // try vectorSIMDvsNative(f64, &stdout_writer.interface); - // try stdout_writer.flush(); - // try vectorSIMDvsNative(f32, &stdout_writer.interface); - // try stdout_writer.flush(); - // try vectorSIMDvsNative(i32, &stdout_writer.interface); - // try stdout_writer.flush(); - // try vectorSIMDvsNative(i64, &stdout_writer.interface); - // try stdout_writer.flush(); - // try vectorSIMDvsNative(i128, &stdout_writer.interface); - // try stdout_writer.flush(); + try vectorSIMDvsNative(f64, &stdout_writer.interface); + try stdout_writer.flush(); + try vectorSIMDvsNative(f32, &stdout_writer.interface); + try stdout_writer.flush(); + try vectorSIMDvsNative(i32, &stdout_writer.interface); + try stdout_writer.flush(); + try vectorSIMDvsNative(i64, &stdout_writer.interface); + try stdout_writer.flush(); + try vectorSIMDvsNative(i128, &stdout_writer.interface); + try stdout_writer.flush(); try bench_Scalar(&stdout_writer.interface); try stdout_writer.flush(); - try bench_vsNative(&stdout_writer.interface); - try stdout_writer.flush(); - try bench_crossTypeVsNative(&stdout_writer.interface); - try stdout_writer.flush(); + // try bench_vsNative(&stdout_writer.interface); + // try stdout_writer.flush(); + // try bench_crossTypeVsNative(&stdout_writer.interface); + // try stdout_writer.flush(); try bench_Vector(&stdout_writer.interface); try stdout_writer.flush(); } @@ -97,9 +96,9 @@ fn bench_Scalar(writer: *std.Io.Writer) 
!void { comptime var tidx: usize = 0; inline for (Types, TNames) |T, tname| { - const M = Scalar(T, .{ .L = 1 }, .{}); - const KM = Scalar(T, .{ .L = 1 }, .{ .L = .k }); - const S = Scalar(T, .{ .T = 1 }, .{}); + const M = Tensor(T, .{ .L = 1 }, .{}, &.{1}); + const KM = Tensor(T, .{ .L = 1 }, .{ .L = .k }, &.{1}); + const S = Tensor(T, .{ .T = 1 }, .{}, &.{1}); inline for (Ops, 0..) |op_name, oidx| { var samples: [SAMPLES]f64 = undefined; @@ -199,8 +198,8 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { var native_total_ns: f64 = 0; var quantity_total_ns: f64 = 0; - const M = Scalar(T, .{ .L = 1 }, .{}); - const S = Scalar(T, .{ .T = 1 }, .{}); + const M = Tensor(T, .{ .L = 1 }, .{}, &.{1}); + const S = Tensor(T, .{ .T = 1 }, .{}, &.{1}); std.mem.doNotOptimizeAway({ for (0..SAMPLES) |_| { @@ -321,9 +320,9 @@ fn bench_crossTypeVsNative(writer: *std.Io.Writer) !void { var native_total_ns: f64 = 0; var quantity_total_ns: f64 = 0; - const M1 = Scalar(T1, .{ .L = 1 }, .{}); - const M2 = Scalar(T2, .{ .L = 1 }, .{}); - const S2 = Scalar(T2, .{ .T = 1 }, .{}); + const M1 = Tensor(T1, .{ .L = 1 }, .{}, &.{1}); + const M2 = Tensor(T2, .{ .L = 1 }, .{}, &.{1}); + const S2 = Tensor(T2, .{ .T = 1 }, .{}, &.{1}); std.mem.doNotOptimizeAway({ for (0..SAMPLES) |_| { @@ -429,9 +428,8 @@ fn bench_Vector(writer: *std.Io.Writer) !void { try writer.print("│ {s:<16} │ {s:<4} │", .{ op_name, tname }); inline for (Lengths) |len| { - const Q_base = Scalar(T, .{ .L = 1 }, .{}); - const Q_time = Scalar(T, .{ .T = 1 }, .{}); - const V = Vector(len, Q_base); + const Q_time = Tensor(T, .{ .T = 1 }, .{}, &.{1}); + const V = Tensor(T, .{ .L = 1 }, .{}, &.{len}); // cross product is only defined for len == 3 const is_cross = comptime std.mem.eql(u8, op_name, "cross"); @@ -455,10 +453,10 @@ fn bench_Vector(writer: *std.Io.Writer) !void { _ = v1.div(V.splat(getVal(T, i +% 2, 63))); } else if (comptime std.mem.eql(u8, op_name, "mulScalar")) { const s_val = Q_time.splat(getVal(T, i +% 2, 63)); - 
_ = v1.mulScalar(s_val); + _ = v1.mul(s_val); } else if (comptime std.mem.eql(u8, op_name, "dot")) { const v2 = V.splat(getVal(T, i +% 5, 63)); - _ = v1.dot(v2); + _ = v1.contract(v2, 0, 0); } else if (comptime std.mem.eql(u8, op_name, "cross")) { // len == 3 guaranteed by the guard above const v2 = V.splat(getVal(T, i +% 5, 63));