Fixed benchmark
This commit is contained in:
parent
4595397e70
commit
18830c8b45
@ -23,10 +23,10 @@ pub fn TensorStatic(
|
|||||||
comptime shape_: []const comptime_int,
|
comptime shape_: []const comptime_int,
|
||||||
) type {
|
) type {
|
||||||
comptime {
|
comptime {
|
||||||
if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1).");
|
if (shape_.len == 0)
|
||||||
for (shape_) |s| {
|
@compileError("Tensor shape must have at least 1 dimension (rank >= 1).");
|
||||||
if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1.");
|
for (shape_) |s|
|
||||||
}
|
if (s < 1) @compileError("Tensor shape dimensions must be strictly >= 1.");
|
||||||
}
|
}
|
||||||
@setEvalBranchQuota(100_000_000);
|
@setEvalBranchQuota(100_000_000);
|
||||||
|
|
||||||
@ -112,7 +112,7 @@ pub fn TensorStatic(
|
|||||||
@compileError("rhs can only be a Tensor ");
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (comptime !dims.eql(RhsType.dims))
|
if (comptime !dims.eql(RhsType.dims))
|
||||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||||
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
|
|
||||||
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||||
@ -127,10 +127,12 @@ pub fn TensorStatic(
|
|||||||
T,
|
T,
|
||||||
dims.add(@TypeOf(rhs).dims).argsOpt(),
|
dims.add(@TypeOf(rhs).dims).argsOpt(),
|
||||||
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
||||||
shape_,
|
shape,
|
||||||
) {
|
) {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
if (comptime !isTensor(RhsType))
|
||||||
|
@compileError("rhs can only be a Tensor ");
|
||||||
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||||
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
|
|
||||||
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||||
@ -146,10 +148,12 @@ pub fn TensorStatic(
|
|||||||
T,
|
T,
|
||||||
dims.sub(@TypeOf(rhs).dims).argsOpt(),
|
dims.sub(@TypeOf(rhs).dims).argsOpt(),
|
||||||
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
||||||
shape_,
|
shape,
|
||||||
) {
|
) {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
if (comptime !isTensor(RhsType))
|
||||||
|
@compileError("rhs can only be a Tensor ");
|
||||||
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||||
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
|
|
||||||
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||||
@ -169,7 +173,7 @@ pub fn TensorStatic(
|
|||||||
T,
|
T,
|
||||||
dims.scale(exp).argsOpt(),
|
dims.scale(exp).argsOpt(),
|
||||||
scales.argsOpt(),
|
scales.argsOpt(),
|
||||||
shape_,
|
shape,
|
||||||
) {
|
) {
|
||||||
if (comptime exp == 0) return .{ .data = @splat(1) };
|
if (comptime exp == 0) return .{ .data = @splat(1) };
|
||||||
if (comptime exp == 1) return self;
|
if (comptime exp == 1) return self;
|
||||||
@ -199,7 +203,7 @@ pub fn TensorStatic(
|
|||||||
T,
|
T,
|
||||||
dims.div(2).argsOpt(),
|
dims.div(2).argsOpt(),
|
||||||
scales.argsOpt(),
|
scales.argsOpt(),
|
||||||
shape_,
|
shape,
|
||||||
) {
|
) {
|
||||||
if (comptime !dims.isSquare())
|
if (comptime !dims.isSquare())
|
||||||
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
|
@compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even.");
|
||||||
@ -235,7 +239,7 @@ pub fn TensorStatic(
|
|||||||
// Run validation checks FIRST before dealing with types
|
// Run validation checks FIRST before dealing with types
|
||||||
if (comptime !dims.eql(Dest.dims))
|
if (comptime !dims.eql(Dest.dims))
|
||||||
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ Dest.dims.str());
|
@compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ Dest.dims.str());
|
||||||
if (comptime total != 1 and !sh.shapeEql(shape_, Dest.shape))
|
if (comptime total != 1 and !sh.shapeEql(shape, Dest.shape))
|
||||||
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
|
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
|
||||||
|
|
||||||
const vec = if (comptime total == 1 and Dest.total != 1)
|
const vec = if (comptime total == 1 and Dest.total != 1)
|
||||||
@ -317,10 +321,12 @@ pub fn TensorStatic(
|
|||||||
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
/// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed.
|
||||||
inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: Vec, r: Vec } {
|
inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: Vec, r: Vec } {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape))
|
if (comptime !isTensor(RhsType))
|
||||||
|
@compileError("rhs can only be a Tensor ");
|
||||||
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||||
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
|
|
||||||
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_);
|
const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||||
return .{ .l = self.to(TargetType).data, .r = rhs.to(TargetType).data };
|
return .{ .l = self.to(TargetType).data, .r = rhs.to(TargetType).data };
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -381,49 +387,51 @@ pub fn TensorStatic(
|
|||||||
|
|
||||||
pub inline fn contract(
|
pub inline fn contract(
|
||||||
self: *const Self,
|
self: *const Self,
|
||||||
other: anytype,
|
rhs: anytype,
|
||||||
comptime axis_a: usize,
|
comptime axis_a: usize,
|
||||||
comptime axis_b: usize,
|
comptime axis_b: usize,
|
||||||
) blk: {
|
) blk: {
|
||||||
const OT = @TypeOf(other);
|
const RhsType = @TypeOf(rhs);
|
||||||
|
if (!isTensor(RhsType))
|
||||||
|
@compileError("rhs can only be a Tensor ");
|
||||||
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
|
if (axis_a >= rank) @compileError("contract: axis_a out of bounds");
|
||||||
if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds");
|
if (axis_b >= RhsType.rank) @compileError("contract: axis_b out of bounds");
|
||||||
if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes");
|
if (shape[axis_a] != RhsType.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes");
|
||||||
|
|
||||||
const sa = sh.shapeRemoveAxis(shape_, axis_a);
|
const sa = sh.shapeRemoveAxis(shape, axis_a);
|
||||||
const sb = sh.shapeRemoveAxis(OT.shape, axis_b);
|
const sb = sh.shapeRemoveAxis(RhsType.shape, axis_b);
|
||||||
const rs_raw = sh.shapeCat(&sa, &sb);
|
const rs_raw = sh.shapeCat(&sa, &sb);
|
||||||
const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw;
|
const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw;
|
||||||
break :blk TensorStatic(
|
break :blk TensorStatic(
|
||||||
T,
|
T,
|
||||||
dims.add(OT.dims).argsOpt(),
|
dims.add(RhsType.dims).argsOpt(),
|
||||||
sh.finerScales(Self, OT).argsOpt(),
|
sh.finerScales(Self, RhsType).argsOpt(),
|
||||||
rs,
|
rs,
|
||||||
);
|
);
|
||||||
} {
|
} {
|
||||||
const OT = @TypeOf(other);
|
const RhsType = @TypeOf(rhs);
|
||||||
const k: usize = comptime shape_[axis_a]; // contraction dimension
|
const k: usize = comptime shape[axis_a]; // contraction dimension
|
||||||
|
|
||||||
const sa = comptime sh.shapeRemoveAxis(shape_, axis_a);
|
const sa = comptime sh.shapeRemoveAxis(shape, axis_a);
|
||||||
const sb = comptime sh.shapeRemoveAxis(OT.shape, axis_b);
|
const sb = comptime sh.shapeRemoveAxis(RhsType.shape, axis_b);
|
||||||
const rs_raw = comptime sh.shapeCat(&sa, &sb);
|
const rs_raw = comptime sh.shapeCat(&sa, &sb);
|
||||||
const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;
|
const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw;
|
||||||
|
|
||||||
const ResultType = TensorStatic(
|
const ResultType = TensorStatic(
|
||||||
T,
|
T,
|
||||||
dims.add(OT.dims).argsOpt(),
|
dims.add(RhsType.dims).argsOpt(),
|
||||||
sh.finerScales(Self, OT).argsOpt(),
|
sh.finerScales(Self, RhsType).argsOpt(),
|
||||||
rs,
|
rs,
|
||||||
);
|
);
|
||||||
|
|
||||||
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, OT).argsOpt(), shape_);
|
const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||||
const OtherNorm = TensorStatic(T, OT.dims.argsOpt(), sh.finerScales(Self, OT).argsOpt(), OT.shape);
|
const OtherNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape);
|
||||||
|
|
||||||
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data;
|
||||||
const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data;
|
const b_data = if (comptime RhsType == OtherNorm) rhs.data else rhs.to(OtherNorm).data;
|
||||||
|
|
||||||
// FAST PATH: Dot Product
|
// FAST PATH: Dot Product
|
||||||
if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) {
|
if (comptime rank == 1 and RhsType.rank == 1 and axis_a == 0 and axis_b == 0) {
|
||||||
if (comptime !sh.isInt(T)) {
|
if (comptime !sh.isInt(T)) {
|
||||||
return .{ .data = @splat(@reduce(.Add, a_data * b_data)) };
|
return .{ .data = @splat(@reduce(.Add, a_data * b_data)) };
|
||||||
} else {
|
} else {
|
||||||
@ -438,13 +446,13 @@ pub fn TensorStatic(
|
|||||||
|
|
||||||
// --- ZERO-COST COERCION TO ARRAYS FOR RUNTIME INDEXING ---
|
// --- ZERO-COST COERCION TO ARRAYS FOR RUNTIME INDEXING ---
|
||||||
const a_arr: [total]T = a_data;
|
const a_arr: [total]T = a_data;
|
||||||
const b_arr: [OT.total]T = b_data;
|
const b_arr: [RhsType.total]T = b_data;
|
||||||
|
|
||||||
// FAST PATH: 2D Matrix Multiplication
|
// FAST PATH: 2D Matrix Multiplication
|
||||||
if (comptime rank == 2 and OT.rank == 2 and axis_a == 1 and axis_b == 0) {
|
if (comptime rank == 2 and RhsType.rank == 2 and axis_a == 1 and axis_b == 0) {
|
||||||
const rows = shape_[0];
|
const rows = shape[0];
|
||||||
const cols = OT.shape[1];
|
const cols = RhsType.shape[1];
|
||||||
const inner = shape_[1];
|
const inner = shape[1];
|
||||||
|
|
||||||
// Create a mutable array for the result, NOT a Tensor struct
|
// Create a mutable array for the result, NOT a Tensor struct
|
||||||
var res_arr: [ResultType.total]T = undefined;
|
var res_arr: [ResultType.total]T = undefined;
|
||||||
@ -454,7 +462,7 @@ pub fn TensorStatic(
|
|||||||
var acc: T = 0;
|
var acc: T = 0;
|
||||||
for (0..inner) |id| {
|
for (0..inner) |id| {
|
||||||
const a_flat = i * _strides[0] + id * _strides[1];
|
const a_flat = i * _strides[0] + id * _strides[1];
|
||||||
const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1];
|
const b_flat = id * RhsType.strides_arr[0] + j * RhsType.strides_arr[1];
|
||||||
|
|
||||||
// Use a_arr and b_arr here
|
// Use a_arr and b_arr here
|
||||||
if (comptime sh.isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
|
if (comptime sh.isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
|
||||||
@ -484,9 +492,9 @@ pub fn TensorStatic(
|
|||||||
var acc: T = 0;
|
var acc: T = 0;
|
||||||
for (0..k) |ki| {
|
for (0..k) |ki| {
|
||||||
const a_coords = sh.insertAxis(rank, axis_a, ki, &a_free);
|
const a_coords = sh.insertAxis(rank, axis_a, ki, &a_free);
|
||||||
const b_coords = sh.insertAxis(OT.rank, axis_b, ki, &b_free);
|
const b_coords = sh.insertAxis(RhsType.rank, axis_b, ki, &b_free);
|
||||||
const a_flat = sh.encodeFlatCoords(&a_coords, rank, _strides);
|
const a_flat = sh.encodeFlatCoords(&a_coords, rank, _strides);
|
||||||
const b_flat = sh.encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr);
|
const b_flat = sh.encodeFlatCoords(&b_coords, RhsType.rank, RhsType.strides_arr);
|
||||||
|
|
||||||
// Use a_arr and b_arr here
|
// Use a_arr and b_arr here
|
||||||
if (comptime sh.isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
|
if (comptime sh.isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat];
|
||||||
@ -509,9 +517,10 @@ pub fn TensorStatic(
|
|||||||
) {
|
) {
|
||||||
const RhsType = @TypeOf(rhs);
|
const RhsType = @TypeOf(rhs);
|
||||||
|
|
||||||
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) {
|
if (!isTensor(RhsType))
|
||||||
|
@compileError("rhs can only be a Tensor ");
|
||||||
|
if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3)
|
||||||
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
|
@compileError("cross product is only defined for 3D vectors (rank-1, length 3)");
|
||||||
}
|
|
||||||
|
|
||||||
// Bring both to the same scale (e.g., mm vs m)
|
// Bring both to the same scale (e.g., mm vs m)
|
||||||
const p = self.resolveScalePair(rhs);
|
const p = self.resolveScalePair(rhs);
|
||||||
|
|||||||
@ -10,27 +10,27 @@ pub fn main(init: std.process.Init) !void {
|
|||||||
|
|
||||||
io = init.io;
|
io = init.io;
|
||||||
|
|
||||||
// try vectorSIMDvsNative(f64, &stdout_writer.interface);
|
try vectorSIMDvsNative(f64, &stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
// try vectorSIMDvsNative(f32, &stdout_writer.interface);
|
try vectorSIMDvsNative(f32, &stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
// try vectorSIMDvsNative(i32, &stdout_writer.interface);
|
try vectorSIMDvsNative(i32, &stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
// try vectorSIMDvsNative(i64, &stdout_writer.interface);
|
try vectorSIMDvsNative(i64, &stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
// try vectorSIMDvsNative(i128, &stdout_writer.interface);
|
try vectorSIMDvsNative(i128, &stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
//
|
|
||||||
// try bench_Scalar(&stdout_writer.interface);
|
try bench_Scalar(&stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
try bench_vsNative(&stdout_writer.interface);
|
try bench_vsNative(&stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
// try bench_crossTypeVsNative(&stdout_writer.interface);
|
try bench_crossTypeVsNative(&stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
// try bench_Vector(&stdout_writer.interface);
|
try bench_Vector(&stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
// try bench_HighDimTensor(&stdout_writer.interface);
|
try bench_HighDimTensor(&stdout_writer.interface);
|
||||||
// try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn getTime() Io.Timestamp {
|
fn getTime() Io.Timestamp {
|
||||||
@ -128,7 +128,7 @@ fn bench_Scalar(writer: *std.Io.Writer) !void {
|
|||||||
else if (comptime std.mem.eql(u8, op_name, "gt"))
|
else if (comptime std.mem.eql(u8, op_name, "gt"))
|
||||||
(M.splat(getVal(T, i, 63))).gt(M.splat(getVal(T, i +% 3, 63)))
|
(M.splat(getVal(T, i, 63))).gt(M.splat(getVal(T, i +% 3, 63)))
|
||||||
else
|
else
|
||||||
(M.splat(getVal(T, i, 63))).mul(3);
|
(M.splat(getVal(T, i, 63))).mul(M.splat(3));
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user