Fixed benchmark

This commit is contained in:
adrien 2026-05-04 22:25:18 +02:00
parent 4595397e70
commit 18830c8b45
2 changed files with 71 additions and 62 deletions

View File

@ -23,10 +23,10 @@ pub fn TensorStatic(
comptime shape_: []const comptime_int,
) type {
comptime {
if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1).");
for (shape_) |s| {
if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1.");
}
if (shape_.len == 0)
@compileError("Tensor shape must have at least 1 dimension (rank >= 1).");
for (shape_) |s|
if (s < 1) @compileError("Tensor shape dimensions must be strictly >= 1.");
}
@setEvalBranchQuota(100_000_000);
@ -112,7 +112,7 @@ pub fn TensorStatic(
@compileError("rhs can only be a Tensor ");
if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
@ -127,10 +127,12 @@ pub fn TensorStatic(
T,
dims.add(@TypeOf(rhs).dims).argsOpt(),
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
shape_,
shape,
) {
const RhsType = @TypeOf(rhs);
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
if (comptime !isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
@ -146,10 +148,12 @@ pub fn TensorStatic(
T,
dims.sub(@TypeOf(rhs).dims).argsOpt(),
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
shape_,
shape,
) {
const RhsType = @TypeOf(rhs);
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
if (comptime !isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
@ -169,7 +173,7 @@ pub fn TensorStatic(
T,
dims.scale(exp).argsOpt(),
scales.argsOpt(),
shape_,
shape,
) {
if (comptime exp == 0) return .{ .data = @splat(1) };
if (comptime exp == 1) return self;
@ -199,7 +203,7 @@ pub fn TensorStatic(
T,
dims.div(2).argsOpt(),
scales.argsOpt(),
shape_,
shape,
) {
if (comptime !dims.isSquare())
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
@ -235,7 +239,7 @@ pub fn TensorStatic(
// Run validation checks FIRST before dealing with types
if (comptime !dims.eql(Dest.dims))
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ Dest.dims.str());
if (comptime total != 1 and !sh.shapeEql(shape_, Dest.shape))
if (comptime total != 1 and !sh.shapeEql(shape, Dest.shape))
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
const vec = if (comptime total == 1 and Dest.total != 1)
@ -317,10 +321,12 @@ pub fn TensorStatic(
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: Vec, r: Vec } {
const RhsType = @TypeOf(rhs);
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
if (comptime !isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_);
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
return .{ .l = self.to(TargetType).data, .r = rhs.to(TargetType).data };
}
@ -381,49 +387,51 @@ pub fn TensorStatic(
pub inline fn contract(
self: *const Self,
other: anytype,
rhs: anytype,
comptime axis_a: usize,
comptime axis_b: usize,
) blk: {
const OT = @TypeOf(other);
const RhsType = @TypeOf(rhs);
if (!isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds");
if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes");
if (axis_b >= RhsType.rank) @compileError("contract: axis_b out of bounds");
if (shape[axis_a] != RhsType.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes");
const sa = sh.shapeRemoveAxis(shape_, axis_a);
const sb = sh.shapeRemoveAxis(OT.shape, axis_b);
const sa = sh.shapeRemoveAxis(shape, axis_a);
const sb = sh.shapeRemoveAxis(RhsType.shape, axis_b);
const rs_raw = sh.shapeCat(&sa, &sb);
const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw;
break :blk TensorStatic(
T,
dims.add(OT.dims).argsOpt(),
sh.finerScales(Self, OT).argsOpt(),
dims.add(RhsType.dims).argsOpt(),
sh.finerScales(Self, RhsType).argsOpt(),
rs,
);
} {
const OT = @TypeOf(other);
const k: usize = comptime shape_[axis_a]; // contraction dimension
const RhsType = @TypeOf(rhs);
const k: usize = comptime shape[axis_a]; // contraction dimension
const sa = comptime sh.shapeRemoveAxis(shape_, axis_a);
const sb = comptime sh.shapeRemoveAxis(OT.shape, axis_b);
const sa = comptime sh.shapeRemoveAxis(shape, axis_a);
const sb = comptime sh.shapeRemoveAxis(RhsType.shape, axis_b);
const rs_raw = comptime sh.shapeCat(&sa, &sb);
const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;
const ResultType = TensorStatic(
T,
dims.add(OT.dims).argsOpt(),
sh.finerScales(Self, OT).argsOpt(),
dims.add(RhsType.dims).argsOpt(),
sh.finerScales(Self, RhsType).argsOpt(),
rs,
);
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, OT).argsOpt(), shape_);
const OtherNorm = TensorStatic(T, OT.dims.argsOpt(), sh.finerScales(Self, OT).argsOpt(), OT.shape);
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
const OtherNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;
const b_data = if (comptime RhsType == OtherNorm) rhs.data else rhs.to(OtherNorm).data;
// FAST PATH: Dot Product
if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) {
if (comptime rank == 1 and RhsType.rank == 1 and axis_a == 0 and axis_b == 0) {
if (comptime !sh.isInt(T)) {
return .{ .data = @splat(@reduce(.Add, a_data * b_data)) };
} else {
@ -438,13 +446,13 @@ pub fn TensorStatic(
// --- ZERO-COST COERCION TO ARRAYS FOR RUNTIME INDEXING ---
const a_arr: [total]T = a_data;
const b_arr: [OT.total]T = b_data;
const b_arr: [RhsType.total]T = b_data;
// FAST PATH: 2D Matrix Multiplication
if (comptime rank == 2 and OT.rank == 2 and axis_a == 1 and axis_b == 0) {
const rows = shape_[0];
const cols = OT.shape[1];
const inner = shape_[1];
if (comptime rank == 2 and RhsType.rank == 2 and axis_a == 1 and axis_b == 0) {
const rows = shape[0];
const cols = RhsType.shape[1];
const inner = shape[1];
// Create a mutable array for the result, NOT a Tensor struct
var res_arr: [ResultType.total]T = undefined;
@ -454,7 +462,7 @@ pub fn TensorStatic(
var acc: T = 0;
for (0..inner) |id| {
const a_flat = i * _strides[0] + id * _strides[1];
const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1];
const b_flat = id * RhsType.strides_arr[0] + j * RhsType.strides_arr[1];
// Use a_arr and b_arr here
if (comptime sh.isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
@ -484,9 +492,9 @@ pub fn TensorStatic(
var acc: T = 0;
for (0..k) |ki| {
const a_coords = sh.insertAxis(rank, axis_a, ki, &a_free);
const b_coords = sh.insertAxis(OT.rank, axis_b, ki, &b_free);
const b_coords = sh.insertAxis(RhsType.rank, axis_b, ki, &b_free);
const a_flat = sh.encodeFlatCoords(&a_coords, rank, _strides);
const b_flat = sh.encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
const b_flat = sh.encodeFlatCoords(&b_coords, RhsType.rank, RhsType.strides_arr);
// Use a_arr and b_arr here
if (comptime sh.isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
@ -509,9 +517,10 @@ pub fn TensorStatic(
) {
const RhsType = @TypeOf(rhs);
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) {
if (!isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3)
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
}
// Bring both to the same scale (e.g., mm vs m)
const p = self.resolveScalePair(rhs);

View File

@ -10,27 +10,27 @@ pub fn main(init: std.process.Init) !void {
io = init.io;
// try vectorSIMDvsNative(f64, &stdout_writer.interface);
// try stdout_writer.flush();
// try vectorSIMDvsNative(f32, &stdout_writer.interface);
// try stdout_writer.flush();
// try vectorSIMDvsNative(i32, &stdout_writer.interface);
// try stdout_writer.flush();
// try vectorSIMDvsNative(i64, &stdout_writer.interface);
// try stdout_writer.flush();
// try vectorSIMDvsNative(i128, &stdout_writer.interface);
// try stdout_writer.flush();
//
// try bench_Scalar(&stdout_writer.interface);
// try stdout_writer.flush();
try vectorSIMDvsNative(f64, &stdout_writer.interface);
try stdout_writer.flush();
try vectorSIMDvsNative(f32, &stdout_writer.interface);
try stdout_writer.flush();
try vectorSIMDvsNative(i32, &stdout_writer.interface);
try stdout_writer.flush();
try vectorSIMDvsNative(i64, &stdout_writer.interface);
try stdout_writer.flush();
try vectorSIMDvsNative(i128, &stdout_writer.interface);
try stdout_writer.flush();
try bench_Scalar(&stdout_writer.interface);
try stdout_writer.flush();
try bench_vsNative(&stdout_writer.interface);
try stdout_writer.flush();
// try bench_crossTypeVsNative(&stdout_writer.interface);
// try stdout_writer.flush();
// try bench_Vector(&stdout_writer.interface);
// try stdout_writer.flush();
// try bench_HighDimTensor(&stdout_writer.interface);
// try stdout_writer.flush();
try bench_crossTypeVsNative(&stdout_writer.interface);
try stdout_writer.flush();
try bench_Vector(&stdout_writer.interface);
try stdout_writer.flush();
try bench_HighDimTensor(&stdout_writer.interface);
try stdout_writer.flush();
}
fn getTime() Io.Timestamp {
@ -128,7 +128,7 @@ fn bench_Scalar(writer: *std.Io.Writer) !void {
else if (comptime std.mem.eql(u8, op_name, "gt"))
(M.splat(getVal(T, i, 63))).gt(M.splat(getVal(T, i +% 3, 63)))
else
(M.splat(getVal(T, i, 63))).mul(3);
(M.splat(getVal(T, i, 63))).mul(M.splat(3));
},
);
}