Start to optimize the shit out of it, still a long way to go

After that GPPPPPUUUUUU baby!
This commit is contained in:
adrien 2026-04-27 21:24:03 +02:00
parent 168312b78e
commit a0961e7571
4 changed files with 57 additions and 61 deletions

View File

@ -2,7 +2,7 @@ const std = @import("std");
pub fn build(b: *std.Build) void { pub fn build(b: *std.Build) void {
const target = b.standardTargetOptions(.{}); const target = b.standardTargetOptions(.{});
const optimize = b.standardOptimizeOption(.{}); const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseFast });
// 1. Define the module so other projects can import it // 1. Define the module so other projects can import it
_ = b.addModule("dimal", .{ _ = b.addModule("dimal", .{

View File

@ -186,12 +186,12 @@ test "BaseQuantities - Kinematics equations" {
// Velocity = Distance / Time // Velocity = Distance / Time
const v = d.div(t); const v = d.div(t);
try std.testing.expectEqual(25.0, v.data[0]); try std.testing.expectEqual(25.0, v.data[0]);
try std.testing.expect(Speed.dims.eql(@TypeOf(v).dims)); try comptime std.testing.expect(Speed.dims.eql(@TypeOf(v).dims));
// Acceleration = Velocity / Time // Acceleration = Velocity / Time
const a = v.div(t); const a = v.div(t);
try std.testing.expectEqual(12.5, a.data[0]); try std.testing.expectEqual(12.5, a.data[0]);
try std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims)); try comptime std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims));
} }
test "BaseQuantities - Dynamics (Force and Work)" { test "BaseQuantities - Dynamics (Force and Work)" {
@ -203,13 +203,13 @@ test "BaseQuantities - Dynamics (Force and Work)" {
// Force = mass * acceleration // Force = mass * acceleration
const f = m.mul(a); const f = m.mul(a);
try std.testing.expectEqual(98, f.data[0]); try std.testing.expectEqual(98, f.data[0]);
try std.testing.expect(Force.dims.eql(@TypeOf(f).dims)); try comptime std.testing.expect(Force.dims.eql(@TypeOf(f).dims));
// Energy (Work) = Force * distance // Energy (Work) = Force * distance
const distance = Meter.Of(f32).splat(5.0); const distance = Meter.Of(f32).splat(5.0);
const energy = f.mul(distance); const energy = f.mul(distance);
try std.testing.expectEqual(490, energy.data[0]); try std.testing.expectEqual(490, energy.data[0]);
try std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims)); try comptime std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims));
} }
test "BaseQuantities - Electric combinations" { test "BaseQuantities - Electric combinations" {
@ -219,7 +219,7 @@ test "BaseQuantities - Electric combinations" {
// Charge = Current * time // Charge = Current * time
const charge = current.mul(time); const charge = current.mul(time);
try std.testing.expectEqual(6.0, charge.data[0]); try std.testing.expectEqual(6.0, charge.data[0]);
try std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims)); try comptime std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims));
} }
test "Constants - Initialization and dimension checks" { test "Constants - Initialization and dimension checks" {

View File

@ -8,14 +8,14 @@ const Dimension = Dimensions.Dimension;
// Comptime utilities // Comptime utilities
// //
pub fn shapeTotal(comptime shape: []const comptime_int) usize { pub fn shapeTotal(shape: []const comptime_int) usize {
var t: comptime_int = 1; var t: comptime_int = 1;
for (shape) |s| t *= s; for (shape) |s| t *= s;
return t; return t;
} }
/// Check if two shapes are strictly identical. /// Check if two shapes are strictly identical.
pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_int) bool { pub fn shapeEql(a: []const comptime_int, b: []const comptime_int) bool {
if (a.len != b.len) return false; if (a.len != b.len) return false;
for (a, 0..) |v, i| for (a, 0..) |v, i|
if (v != b[i]) return false; if (v != b[i]) return false;
@ -25,7 +25,7 @@ pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_i
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). /// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
/// e.g. shape {3, 4} strides {4, 1} /// e.g. shape {3, 4} strides {4, 1}
/// shape {2, 3, 4} strides {12, 4, 1} /// shape {2, 3, 4} strides {12, 4, 1}
pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_int { pub fn shapeStrides(shape: []const comptime_int) [shape.len]comptime_int {
var st: [shape.len]comptime_int = undefined; var st: [shape.len]comptime_int = undefined;
if (shape.len == 0) return st; if (shape.len == 0) return st;
st[shape.len - 1] = 1; st[shape.len - 1] = 1;
@ -37,7 +37,7 @@ pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_in
} }
/// Return a copy of `shape` with the element at `axis` removed. /// Return a copy of `shape` with the element at `axis` removed.
pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comptime_int) [shape.len - 1]comptime_int { pub fn shapeRemoveAxis(shape: []const comptime_int, axis: comptime_int) [shape.len - 1]comptime_int {
var out: [shape.len - 1]comptime_int = undefined; var out: [shape.len - 1]comptime_int = undefined;
var j: comptime_int = 0; var j: comptime_int = 0;
for (shape, 0..) |v, i| { for (shape, 0..) |v, i| {
@ -50,7 +50,7 @@ pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comp
} }
/// Concatenate two compile-time slices. /// Concatenate two compile-time slices.
pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_int) [a.len + b.len]comptime_int { pub fn shapeCat(a: []const comptime_int, b: []const comptime_int) [a.len + b.len]comptime_int {
var out: [a.len + b.len]comptime_int = undefined; var out: [a.len + b.len]comptime_int = undefined;
for (a, 0..) |v, i| out[i] = v; for (a, 0..) |v, i| out[i] = v;
for (b, 0..) |v, i| out[a.len + i] = v; for (b, 0..) |v, i| out[a.len + i] = v;
@ -59,11 +59,7 @@ pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_i
/// Decode a flat row-major index into N-D coordinates. /// Decode a flat row-major index into N-D coordinates.
/// Called only in comptime contexts (all arguments are comptime). /// Called only in comptime contexts (all arguments are comptime).
pub fn decodeFlatCoords( pub fn decodeFlatCoords(flat: comptime_int, n: comptime_int, strd: [n]comptime_int) [n]usize {
comptime flat: comptime_int,
comptime n: comptime_int,
comptime strd: [n]comptime_int,
) [n]usize {
var coords: [n]comptime_int = undefined; var coords: [n]comptime_int = undefined;
var tmp = flat; var tmp = flat;
for (0..n) |i| { for (0..n) |i| {
@ -75,11 +71,7 @@ pub fn decodeFlatCoords(
/// Encode N-D coordinates into a flat row-major index. /// Encode N-D coordinates into a flat row-major index.
/// Called only in comptime contexts. /// Called only in comptime contexts.
pub fn encodeFlatCoords( pub fn encodeFlatCoords(coords: []const usize, n: usize, strd: [n]usize) usize {
comptime coords: []const usize,
comptime n: usize,
comptime strd: [n]usize,
) usize {
var flat: usize = 0; var flat: usize = 0;
for (0..n) |i| flat += coords[i] * strd[i]; for (0..n) |i| flat += coords[i] * strd[i];
return flat; return flat;
@ -106,7 +98,7 @@ pub fn insertAxis(
return out; return out;
} }
fn isInt(comptime T: type) bool { inline fn isInt(comptime T: type) bool {
return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int; return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int;
} }
@ -130,25 +122,27 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales {
else else
scale1); scale1);
} }
comptime return out; return out;
} }
// //
// File-scope RHS normalisation helpers // File-scope RHS normalisation helpers
// //
fn isTensor(comptime Rhs: type) bool { inline fn isTensor(comptime Rhs: type) bool {
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR"); return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
} }
fn RhsTensorType(comptime T: type, comptime Rhs: type) type { inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
if (comptime isTensor(Rhs)) return Rhs; if (comptime isTensor(Rhs)) return Rhs;
return Tensor(T, .{}, .{}, &.{1}); return Tensor(T, .{}, .{}, &.{1});
} }
fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) { /// Take the anyvalue coming from operation and if it is a Tensor, return it.
/// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r).
inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) {
const Rhs = @TypeOf(r); const Rhs = @TypeOf(r);
if (comptime isTensor(Rhs)) return r; if (comptime isTensor(Rhs)) return r;
const scalar: T = switch (comptime @typeInfo(Rhs)) { const scalar: T = switch (@typeInfo(Rhs)) {
.comptime_int => switch (comptime @typeInfo(T)) { .comptime_int => switch (comptime @typeInfo(T)) {
.float => @as(T, @floatFromInt(r)), .float => @as(T, @floatFromInt(r)),
else => @as(T, r), else => @as(T, r),
@ -278,11 +272,13 @@ pub fn Tensor(
shape_, shape_,
) { ) {
const rhs_q = rhs(r); const rhs_q = rhs(r);
const RhsType = @TypeOf(rhs_q); const RhsType = @TypeOf(r);
if (comptime !dims.eql(RhsType.dims)) if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS."); @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
// if (comptime total == 1 and RhsType.scales.eql(scales))
// return .{ .data = if (comptime isInt(T)) self.data +| r.data else self.data + r.data };
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
@ -450,6 +446,9 @@ pub fn Tensor(
const DestT = ActualDest.ValueType; const DestT = ActualDest.ValueType;
const DestVec = @Vector(total, DestT); const DestVec = @Vector(total, DestT);
if (comptime ratio == 1.0 and T == DestT)
return .{ .data = self.data };
// If ratio is 1, handle type conversion correctly based on BOTH source and dest types // If ratio is 1, handle type conversion correctly based on BOTH source and dest types
if (comptime ratio == 1.0) { if (comptime ratio == 1.0) {
const T_info = @typeInfo(T); const T_info = @typeInfo(T);
@ -533,7 +532,7 @@ pub fn Tensor(
pub inline fn eq(self: Self, r: anytype) CmpResult { pub inline fn eq(self: Self, r: anytype) CmpResult {
const rhs_q = rhs(r); const rhs_q = rhs(r);
if (comptime !dims.eql(@TypeOf(rhs_q).dims)) if (comptime !dims.eql(@TypeOf(rhs_q).dims))
@compileError("Dimension mismatch in eq: " ++ dims.str() ++ " vs " ++ @TypeOf(rhs_q).dims.str()); @compileError("Dimension mismatch in ne.");
const p = resolveScalePair(self, rhs_q); const p = resolveScalePair(self, rhs_q);
return cmpResult(p.l == p.r); return cmpResult(p.l == p.r);
} }

View File

@ -10,27 +10,27 @@ pub fn main(init: std.process.Init) !void {
io = init.io; io = init.io;
try vectorSIMDvsNative(f64, &stdout_writer.interface); // try vectorSIMDvsNative(f64, &stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
try vectorSIMDvsNative(f32, &stdout_writer.interface); // try vectorSIMDvsNative(f32, &stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
try vectorSIMDvsNative(i32, &stdout_writer.interface); // try vectorSIMDvsNative(i32, &stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
try vectorSIMDvsNative(i64, &stdout_writer.interface); // try vectorSIMDvsNative(i64, &stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
try vectorSIMDvsNative(i128, &stdout_writer.interface); // try vectorSIMDvsNative(i128, &stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
//
try bench_Scalar(&stdout_writer.interface); // try bench_Scalar(&stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
try bench_vsNative(&stdout_writer.interface); try bench_vsNative(&stdout_writer.interface);
try stdout_writer.flush(); try stdout_writer.flush();
try bench_crossTypeVsNative(&stdout_writer.interface); // try bench_crossTypeVsNative(&stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
try bench_Vector(&stdout_writer.interface); // try bench_Vector(&stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
try bench_HighDimTensor(&stdout_writer.interface); // try bench_HighDimTensor(&stdout_writer.interface);
try stdout_writer.flush(); // try stdout_writer.flush();
} }
fn getTime() Io.Timestamp { fn getTime() Io.Timestamp {
@ -200,24 +200,22 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
var native_total_ns: f64 = 0; var native_total_ns: f64 = 0;
var quantity_total_ns: f64 = 0; var quantity_total_ns: f64 = 0;
const M = Tensor(T, .{ .L = 1 }, .{}, &.{1}); const M = Tensor(T, .{}, .{}, &.{1});
const S = Tensor(T, .{ .T = 1 }, .{}, &.{1});
std.mem.doNotOptimizeAway({ std.mem.doNotOptimizeAway({
for (0..SAMPLES) |_| { for (0..SAMPLES) |_| {
// --- 1. Benchmark Native --- // --- 1. Benchmark Native ---
const n_start = getTime(); const n_start = getTime();
for (0..ITERS) |i| { const a = getValT(T, 10);
const a = getValT(T, i);
const b = getValT(T, 2); const b = getValT(T, 2);
for (0..ITERS) |_| {
// Native logic branch // Native logic branch
_ = if (comptime std.mem.eql(u8, op_name, "add")) _ = if (comptime std.mem.eql(u8, op_name, "add"))
a + b if (comptime @typeInfo(T) == .int) a +| b else a + b
else if (comptime std.mem.eql(u8, op_name, "sub")) else if (comptime std.mem.eql(u8, op_name, "sub"))
a - b if (comptime @typeInfo(T) == .int) a -| b else a - b
else if (comptime std.mem.eql(u8, op_name, "mul")) else if (comptime std.mem.eql(u8, op_name, "mul"))
a * b if (comptime @typeInfo(T) == .int) a *| b else a * b
else if (comptime std.mem.eql(u8, op_name, "div")) else if (comptime std.mem.eql(u8, op_name, "div"))
if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b
else if (comptime std.mem.eql(u8, op_name, "abs")) else if (comptime std.mem.eql(u8, op_name, "abs"))
@ -234,10 +232,9 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
// --- 2. Benchmark Scalar --- // --- 2. Benchmark Scalar ---
const q_start = getTime(); const q_start = getTime();
for (0..ITERS) |i| { const qa = M.splat(getValT(T, 10));
const qa = M.splat(getValT(T, i)); const qb = M.splat(getValT(T, 2));
const qb = if (comptime std.mem.eql(u8, op_name, "div")) S.splat(getValT(T, 2)) else M.splat(getValT(T, 2)); for (0..ITERS) |_| {
// Scalar logic branch // Scalar logic branch
_ = if (comptime std.mem.eql(u8, op_name, "add")) _ = if (comptime std.mem.eql(u8, op_name, "add"))
qa.add(qb) qa.add(qb)