Added cross to Tensor + fix benchmark
This commit is contained in:
parent
16d25e7e7e
commit
cd954b379b
@ -17,9 +17,8 @@ pub fn shapeTotal(comptime shape: []const usize) usize {
|
|||||||
/// Check if two shapes are strictly identical.
|
/// Check if two shapes are strictly identical.
|
||||||
pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool {
|
pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool {
|
||||||
if (a.len != b.len) return false;
|
if (a.len != b.len) return false;
|
||||||
for (a, 0..) |v, i| {
|
for (a, 0..) |v, i|
|
||||||
if (v != b[i]) return false;
|
if (v != b[i]) return false;
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -432,6 +431,24 @@ pub fn Tensor(
|
|||||||
) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) {
|
) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) {
|
||||||
const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_);
|
const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_);
|
||||||
|
|
||||||
|
if (comptime Self == ActualDest) return self;
|
||||||
|
|
||||||
|
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
|
||||||
|
const DestT = ActualDest.ValueType;
|
||||||
|
const DestVec = @Vector(total, DestT);
|
||||||
|
|
||||||
|
// If ratio is 1, just handle type conversion
|
||||||
|
if (comptime ratio == 1.0) {
|
||||||
|
if (comptime T == DestT) return .{ .data = self.data };
|
||||||
|
return .{
|
||||||
|
.data = switch (comptime @typeInfo(DestT)) {
|
||||||
|
.float => @floatFromInt(self.data), // or @floatCast
|
||||||
|
.int => @intFromFloat(self.data), // or @intCast
|
||||||
|
else => unreachable,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
if (comptime !dims.eql(ActualDest.dims))
|
if (comptime !dims.eql(ActualDest.dims))
|
||||||
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
|
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str());
|
||||||
if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
|
if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape))
|
||||||
@ -439,10 +456,6 @@ pub fn Tensor(
|
|||||||
|
|
||||||
if (comptime Self == ActualDest) return self;
|
if (comptime Self == ActualDest) return self;
|
||||||
|
|
||||||
const DestT = ActualDest.ValueType;
|
|
||||||
const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims));
|
|
||||||
const DestVec = @Vector(total, DestT);
|
|
||||||
|
|
||||||
if (comptime T == DestT) {
|
if (comptime T == DestT) {
|
||||||
if (comptime @typeInfo(T) == .float)
|
if (comptime @typeInfo(T) == .float)
|
||||||
return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) };
|
return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) };
|
||||||
@ -485,7 +498,7 @@ pub fn Tensor(
|
|||||||
const CmpResult = if (total == 1) bool else [total]bool;
|
const CmpResult = if (total == 1) bool else [total]bool;
|
||||||
|
|
||||||
inline fn cmpResult(v: @Vector(total, bool)) CmpResult {
|
inline fn cmpResult(v: @Vector(total, bool)) CmpResult {
|
||||||
return if (comptime total == 1) v[0] else @as([total]bool, v);
|
return if (comptime total == 1) @reduce(.And, v) else @as([total]bool, v);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
||||||
@ -634,6 +647,40 @@ pub fn Tensor(
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 3D Cross Product. Only defined for Rank-1 tensors of length 3.
|
||||||
|
/// Result dimensions are the sum of input dimensions.
|
||||||
|
pub inline fn cross(self: Self, other: anytype) Tensor(
|
||||||
|
T,
|
||||||
|
dims.add(RhsT(@TypeOf(other)).dims).argsOpt(),
|
||||||
|
finerScales(Self, RhsT(@TypeOf(other))).argsOpt(),
|
||||||
|
&.{3},
|
||||||
|
) {
|
||||||
|
const rhs_q = rhs(other);
|
||||||
|
const RhsType = @TypeOf(rhs_q);
|
||||||
|
|
||||||
|
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) {
|
||||||
|
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring both to the same scale (e.g., mm vs m)
|
||||||
|
const p = self.resolveScalePair(rhs_q);
|
||||||
|
const l = p.l;
|
||||||
|
const r = p.r;
|
||||||
|
|
||||||
|
var res: [3]T = undefined;
|
||||||
|
if (comptime isInt(T)) {
|
||||||
|
res[0] = (l[1] *| r[2]) -| (l[2] *| r[1]);
|
||||||
|
res[1] = (l[2] *| r[0]) -| (l[0] *| r[2]);
|
||||||
|
res[2] = (l[0] *| r[1]) -| (l[1] *| r[0]);
|
||||||
|
} else {
|
||||||
|
res[0] = (l[1] * r[2]) - (l[2] * r[1]);
|
||||||
|
res[1] = (l[2] * r[0]) - (l[0] * r[2]);
|
||||||
|
res[2] = (l[0] * r[1]) - (l[1] * r[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return .{ .data = res };
|
||||||
|
}
|
||||||
|
|
||||||
/// Sum of squared elements. Cheaper than length(); use for ordering.
|
/// Sum of squared elements. Cheaper than length(); use for ordering.
|
||||||
pub inline fn lengthSqr(self: Self) T {
|
pub inline fn lengthSqr(self: Self) T {
|
||||||
return @reduce(.Add, self.data * self.data);
|
return @reduce(.Add, self.data * self.data);
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const Io = std.Io;
|
const Io = std.Io;
|
||||||
const Scalar = @import("Quantity.zig").Scalar;
|
const Tensor = @import("Tensor.zig").Tensor;
|
||||||
const Vector = @import("Quantity.zig").Vector;
|
|
||||||
|
|
||||||
var io: Io = undefined;
|
var io: Io = undefined;
|
||||||
pub fn main(init: std.process.Init) !void {
|
pub fn main(init: std.process.Init) !void {
|
||||||
@ -11,23 +10,23 @@ pub fn main(init: std.process.Init) !void {
|
|||||||
|
|
||||||
io = init.io;
|
io = init.io;
|
||||||
|
|
||||||
// try vectorSIMDvsNative(f64, &stdout_writer.interface);
|
try vectorSIMDvsNative(f64, &stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
// try vectorSIMDvsNative(f32, &stdout_writer.interface);
|
try vectorSIMDvsNative(f32, &stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
// try vectorSIMDvsNative(i32, &stdout_writer.interface);
|
try vectorSIMDvsNative(i32, &stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
// try vectorSIMDvsNative(i64, &stdout_writer.interface);
|
try vectorSIMDvsNative(i64, &stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
// try vectorSIMDvsNative(i128, &stdout_writer.interface);
|
try vectorSIMDvsNative(i128, &stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
|
|
||||||
try bench_Scalar(&stdout_writer.interface);
|
try bench_Scalar(&stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
try bench_vsNative(&stdout_writer.interface);
|
// try bench_vsNative(&stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
// try stdout_writer.flush();
|
||||||
try bench_crossTypeVsNative(&stdout_writer.interface);
|
// try bench_crossTypeVsNative(&stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
// try stdout_writer.flush();
|
||||||
try bench_Vector(&stdout_writer.interface);
|
try bench_Vector(&stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
}
|
}
|
||||||
@ -97,9 +96,9 @@ fn bench_Scalar(writer: *std.Io.Writer) !void {
|
|||||||
|
|
||||||
comptime var tidx: usize = 0;
|
comptime var tidx: usize = 0;
|
||||||
inline for (Types, TNames) |T, tname| {
|
inline for (Types, TNames) |T, tname| {
|
||||||
const M = Scalar(T, .{ .L = 1 }, .{});
|
const M = Tensor(T, .{ .L = 1 }, .{}, &.{1});
|
||||||
const KM = Scalar(T, .{ .L = 1 }, .{ .L = .k });
|
const KM = Tensor(T, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
const S = Scalar(T, .{ .T = 1 }, .{});
|
const S = Tensor(T, .{ .T = 1 }, .{}, &.{1});
|
||||||
|
|
||||||
inline for (Ops, 0..) |op_name, oidx| {
|
inline for (Ops, 0..) |op_name, oidx| {
|
||||||
var samples: [SAMPLES]f64 = undefined;
|
var samples: [SAMPLES]f64 = undefined;
|
||||||
@ -199,8 +198,8 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
|
|||||||
var native_total_ns: f64 = 0;
|
var native_total_ns: f64 = 0;
|
||||||
var quantity_total_ns: f64 = 0;
|
var quantity_total_ns: f64 = 0;
|
||||||
|
|
||||||
const M = Scalar(T, .{ .L = 1 }, .{});
|
const M = Tensor(T, .{ .L = 1 }, .{}, &.{1});
|
||||||
const S = Scalar(T, .{ .T = 1 }, .{});
|
const S = Tensor(T, .{ .T = 1 }, .{}, &.{1});
|
||||||
|
|
||||||
std.mem.doNotOptimizeAway({
|
std.mem.doNotOptimizeAway({
|
||||||
for (0..SAMPLES) |_| {
|
for (0..SAMPLES) |_| {
|
||||||
@ -321,9 +320,9 @@ fn bench_crossTypeVsNative(writer: *std.Io.Writer) !void {
|
|||||||
var native_total_ns: f64 = 0;
|
var native_total_ns: f64 = 0;
|
||||||
var quantity_total_ns: f64 = 0;
|
var quantity_total_ns: f64 = 0;
|
||||||
|
|
||||||
const M1 = Scalar(T1, .{ .L = 1 }, .{});
|
const M1 = Tensor(T1, .{ .L = 1 }, .{}, &.{1});
|
||||||
const M2 = Scalar(T2, .{ .L = 1 }, .{});
|
const M2 = Tensor(T2, .{ .L = 1 }, .{}, &.{1});
|
||||||
const S2 = Scalar(T2, .{ .T = 1 }, .{});
|
const S2 = Tensor(T2, .{ .T = 1 }, .{}, &.{1});
|
||||||
|
|
||||||
std.mem.doNotOptimizeAway({
|
std.mem.doNotOptimizeAway({
|
||||||
for (0..SAMPLES) |_| {
|
for (0..SAMPLES) |_| {
|
||||||
@ -429,9 +428,8 @@ fn bench_Vector(writer: *std.Io.Writer) !void {
|
|||||||
try writer.print("│ {s:<16} │ {s:<4} │", .{ op_name, tname });
|
try writer.print("│ {s:<16} │ {s:<4} │", .{ op_name, tname });
|
||||||
|
|
||||||
inline for (Lengths) |len| {
|
inline for (Lengths) |len| {
|
||||||
const Q_base = Scalar(T, .{ .L = 1 }, .{});
|
const Q_time = Tensor(T, .{ .T = 1 }, .{}, &.{1});
|
||||||
const Q_time = Scalar(T, .{ .T = 1 }, .{});
|
const V = Tensor(T, .{ .L = 1 }, .{}, &.{len});
|
||||||
const V = Vector(len, Q_base);
|
|
||||||
|
|
||||||
// cross product is only defined for len == 3
|
// cross product is only defined for len == 3
|
||||||
const is_cross = comptime std.mem.eql(u8, op_name, "cross");
|
const is_cross = comptime std.mem.eql(u8, op_name, "cross");
|
||||||
@ -455,10 +453,10 @@ fn bench_Vector(writer: *std.Io.Writer) !void {
|
|||||||
_ = v1.div(V.splat(getVal(T, i +% 2, 63)));
|
_ = v1.div(V.splat(getVal(T, i +% 2, 63)));
|
||||||
} else if (comptime std.mem.eql(u8, op_name, "mulScalar")) {
|
} else if (comptime std.mem.eql(u8, op_name, "mulScalar")) {
|
||||||
const s_val = Q_time.splat(getVal(T, i +% 2, 63));
|
const s_val = Q_time.splat(getVal(T, i +% 2, 63));
|
||||||
_ = v1.mulScalar(s_val);
|
_ = v1.mul(s_val);
|
||||||
} else if (comptime std.mem.eql(u8, op_name, "dot")) {
|
} else if (comptime std.mem.eql(u8, op_name, "dot")) {
|
||||||
const v2 = V.splat(getVal(T, i +% 5, 63));
|
const v2 = V.splat(getVal(T, i +% 5, 63));
|
||||||
_ = v1.dot(v2);
|
_ = v1.contract(v2, 0, 0);
|
||||||
} else if (comptime std.mem.eql(u8, op_name, "cross")) {
|
} else if (comptime std.mem.eql(u8, op_name, "cross")) {
|
||||||
// len == 3 guaranteed by the guard above
|
// len == 3 guaranteed by the guard above
|
||||||
const v2 = V.splat(getVal(T, i +% 5, 63));
|
const v2 = V.splat(getVal(T, i +% 5, 63));
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user