Start optimizing aggressively, still a long way to go

After that: GPU, baby!
adrien 2026-04-27 21:24:03 +02:00
parent 168312b78e
commit a0961e7571
4 changed files with 57 additions and 61 deletions

@@ -2,7 +2,7 @@ const std = @import("std");
pub fn build(b: *std.Build) void {
const target = b.standardTargetOptions(.{});
-const optimize = b.standardOptimizeOption(.{});
+const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseFast });
// 1. Define the module so other projects can import it
_ = b.addModule("dimal", .{
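Note on the change above: `preferred_optimize_mode` only sets the default optimize mode; an explicit `-Doptimize=` flag on the command line still wins. A minimal standalone sketch of that behaviour (nothing here is repo-specific beyond the option itself):

    const std = @import("std");

    pub fn build(b: *std.Build) void {
        // Builds now default to ReleaseFast, but the user can still
        // override it, e.g. `zig build -Doptimize=Debug`.
        const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseFast });
        _ = optimize;
    }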

@@ -186,12 +186,12 @@ test "BaseQuantities - Kinematics equations" {
// Velocity = Distance / Time
const v = d.div(t);
try std.testing.expectEqual(25.0, v.data[0]);
-try std.testing.expect(Speed.dims.eql(@TypeOf(v).dims));
+try comptime std.testing.expect(Speed.dims.eql(@TypeOf(v).dims));
// Acceleration = Velocity / Time
const a = v.div(t);
try std.testing.expectEqual(12.5, a.data[0]);
-try std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims));
+try comptime std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims));
}
test "BaseQuantities - Dynamics (Force and Work)" {
@@ -203,13 +203,13 @@ test "BaseQuantities - Dynamics (Force and Work)" {
// Force = mass * acceleration
const f = m.mul(a);
try std.testing.expectEqual(98, f.data[0]);
-try std.testing.expect(Force.dims.eql(@TypeOf(f).dims));
+try comptime std.testing.expect(Force.dims.eql(@TypeOf(f).dims));
// Energy (Work) = Force * distance
const distance = Meter.Of(f32).splat(5.0);
const energy = f.mul(distance);
try std.testing.expectEqual(490, energy.data[0]);
-try std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims));
+try comptime std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims));
}
test "BaseQuantities - Electric combinations" {
@@ -219,7 +219,7 @@ test "BaseQuantities - Electric combinations" {
// Charge = Current * time
const charge = current.mul(time);
try std.testing.expectEqual(6.0, charge.data[0]);
-try std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims));
+try comptime std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims));
}
test "Constants - Initialization and dimension checks" {

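The pattern throughout this test file: `dims` are part of the type, so the equality check is comptime-known, and wrapping the expectation in `comptime` moves it entirely to compilation. A failure becomes a compile error instead of a runtime test failure, and no runtime code is emitted for the check. A minimal sketch of the idiom, independent of this repo's types:

    const std = @import("std");

    test "comptime expect" {
        // Evaluated during compilation: a failure here stops the build
        // rather than failing at test runtime.
        try comptime std.testing.expect(@sizeOf(u32) == 4);
    }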
@@ -8,14 +8,14 @@ const Dimension = Dimensions.Dimension;
// Comptime utilities
//
-pub fn shapeTotal(comptime shape: []const comptime_int) usize {
+pub fn shapeTotal(shape: []const comptime_int) usize {
var t: comptime_int = 1;
for (shape) |s| t *= s;
return t;
}
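Dropping the explicit `comptime` keyword on these parameters is safe because `comptime_int` is a comptime-only type: any argument of type `[]const comptime_int` already forces the call to be evaluated at compile time. A small self-contained sketch mirroring the new signature:

    const std = @import("std");

    fn shapeTotal(shape: []const comptime_int) usize {
        var t: comptime_int = 1;
        for (shape) |s| t *= s;
        return t;
    }

    test "implicitly comptime parameters" {
        // The parameter type alone makes this a compile-time call; no
        // `comptime` keyword is needed in the signature or at the call site.
        try std.testing.expectEqual(@as(usize, 24), shapeTotal(&.{ 2, 3, 4 }));
    }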
/// Check if two shapes are strictly identical.
-pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_int) bool {
+pub fn shapeEql(a: []const comptime_int, b: []const comptime_int) bool {
if (a.len != b.len) return false;
for (a, 0..) |v, i|
if (v != b[i]) return false;
@@ -25,7 +25,7 @@ pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_i
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
/// e.g. shape {3, 4} strides {4, 1}
/// shape {2, 3, 4} strides {12, 4, 1}
-pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_int {
+pub fn shapeStrides(shape: []const comptime_int) [shape.len]comptime_int {
var st: [shape.len]comptime_int = undefined;
if (shape.len == 0) return st;
st[shape.len - 1] = 1;
@@ -37,7 +37,7 @@ pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_in
}
/// Return a copy of `shape` with the element at `axis` removed.
-pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comptime_int) [shape.len - 1]comptime_int {
+pub fn shapeRemoveAxis(shape: []const comptime_int, axis: comptime_int) [shape.len - 1]comptime_int {
var out: [shape.len - 1]comptime_int = undefined;
var j: comptime_int = 0;
for (shape, 0..) |v, i| {
@@ -50,7 +50,7 @@ pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comp
}
/// Concatenate two compile-time slices.
-pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_int) [a.len + b.len]comptime_int {
+pub fn shapeCat(a: []const comptime_int, b: []const comptime_int) [a.len + b.len]comptime_int {
var out: [a.len + b.len]comptime_int = undefined;
for (a, 0..) |v, i| out[i] = v;
for (b, 0..) |v, i| out[a.len + i] = v;
@@ -59,11 +59,7 @@ pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_i
/// Decode a flat row-major index into N-D coordinates.
/// Called only in comptime contexts (all arguments are comptime).
-pub fn decodeFlatCoords(
-comptime flat: comptime_int,
-comptime n: comptime_int,
-comptime strd: [n]comptime_int,
-) [n]usize {
+pub fn decodeFlatCoords(flat: comptime_int, n: comptime_int, strd: [n]comptime_int) [n]usize {
var coords: [n]comptime_int = undefined;
var tmp = flat;
for (0..n) |i| {
@@ -75,11 +71,7 @@ pub fn decodeFlatCoords(
/// Encode N-D coordinates into a flat row-major index.
/// Called only in comptime contexts.
-pub fn encodeFlatCoords(
-comptime coords: []const usize,
-comptime n: usize,
-comptime strd: [n]usize,
-) usize {
+pub fn encodeFlatCoords(coords: []const usize, n: usize, strd: [n]usize) usize {
var flat: usize = 0;
for (0..n) |i| flat += coords[i] * strd[i];
return flat;
@@ -106,7 +98,7 @@ pub fn insertAxis(
return out;
}
-fn isInt(comptime T: type) bool {
+inline fn isInt(comptime T: type) bool {
return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int;
}
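`inline fn` is a semantic guarantee rather than a hint: every call site is inlined, and when the arguments are comptime-known the result stays comptime-known there too, so branches on these small predicates fold away. A sketch:

    const std = @import("std");

    inline fn isInt(comptime T: type) bool {
        return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int;
    }

    test "comptime-known predicate folds" {
        // The call inlines and T is comptime-known, so this whole
        // condition is resolved during compilation.
        const saturating = comptime isInt(u32);
        try std.testing.expect(saturating);
    }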
@@ -130,25 +122,27 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales {
else
scale1);
}
-comptime return out;
+return out;
}
//
// File-scope RHS normalisation helpers
//
-fn isTensor(comptime Rhs: type) bool {
+inline fn isTensor(comptime Rhs: type) bool {
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
}
-fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
+inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
if (comptime isTensor(Rhs)) return Rhs;
return Tensor(T, .{}, .{}, &.{1});
}
-fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) {
+/// Take the value coming from an operation; if it is already a Tensor, return it unchanged.
+/// If it is a float or int, return Tensor(T, .{}, .{}, .{1}).splat(r).
+inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) {
const Rhs = @TypeOf(r);
if (comptime isTensor(Rhs)) return r;
-const scalar: T = switch (comptime @typeInfo(Rhs)) {
+const scalar: T = switch (@typeInfo(Rhs)) {
.comptime_int => switch (comptime @typeInfo(T)) {
.float => @as(T, @floatFromInt(r)),
else => @as(T, r),
@@ -278,11 +272,13 @@ pub fn Tensor(
shape_,
) {
const rhs_q = rhs(r);
-const RhsType = @TypeOf(rhs_q);
+const RhsType = @TypeOf(r);
if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
+// if (comptime total == 1 and RhsType.scales.eql(scales))
+// return .{ .data = if (comptime isInt(T)) self.data +| r.data else self.data + r.data };
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
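For context on this `add` path: `toRhsTensor` above funnels every right-hand side through one representation, so the arithmetic never branches on tensor vs. bare scalar. A toy, repo-independent sketch of that normalisation idea (assumes a float element type to keep it short; `OneElem` and `promote` are illustrative names, not the repo's API):

    const std = @import("std");

    fn OneElem(comptime T: type) type {
        return struct { data: @Vector(1, T) };
    }

    // Promote a bare scalar into a one-element wrapper so downstream
    // arithmetic only ever sees the wrapped form.
    fn promote(comptime T: type, r: anytype) OneElem(T) {
        const scalar: T = switch (@typeInfo(@TypeOf(r))) {
            .comptime_int, .int => @floatFromInt(r),
            else => @as(T, r),
        };
        return .{ .data = @splat(scalar) };
    }

    test "scalar promotion" {
        try std.testing.expectEqual(@as(f32, 2.0), promote(f32, 2).data[0]);
        try std.testing.expectEqual(@as(f32, 1.5), promote(f32, 1.5).data[0]);
    }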
@@ -450,6 +446,9 @@ pub fn Tensor(
const DestT = ActualDest.ValueType;
const DestVec = @Vector(total, DestT);
-if (comptime ratio == 1.0 and T == DestT)
-return .{ .data = self.data };
+// If ratio is 1, handle type conversion correctly based on BOTH source and dest types
+if (comptime ratio == 1.0) {
+const T_info = @typeInfo(T);
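The old fast path returned `self.data` unchanged only when ratio == 1.0 and the element types matched; the new branch also covers ratio == 1.0 with differing element types, where no rescaling is needed but each lane must still be converted. The conversion builtins all accept vectors, e.g.:

    const std = @import("std");

    test "lane-wise conversion when only the element type differs" {
        const src: @Vector(4, i32) = .{ 1, 2, 3, 4 };
        // No rescaling, but i32 lanes still have to become f64 lanes;
        // @floatFromInt (like @intFromFloat, @floatCast, @intCast) is
        // defined element-wise on vectors.
        const dst: @Vector(4, f64) = @floatFromInt(src);
        try std.testing.expectEqual(@as(f64, 3.0), dst[2]);
    }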
@@ -533,7 +532,7 @@ pub fn Tensor(
pub inline fn eq(self: Self, r: anytype) CmpResult {
const rhs_q = rhs(r);
if (comptime !dims.eql(@TypeOf(rhs_q).dims))
@compileError("Dimension mismatch in eq: " ++ dims.str() ++ " vs " ++ @TypeOf(rhs_q).dims.str());
@compileError("Dimension mismatch in ne.");
const p = resolveScalePair(self, rhs_q);
return cmpResult(p.l == p.r);
}
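Because `dims` lives in the type, a mismatched `eq` never reaches runtime; the `@compileError` fires during semantic analysis. A standalone miniature of the same guard (a toy `Quantity`, not the repo's Tensor):

    const std = @import("std");

    fn Quantity(comptime dims: comptime_int) type {
        return struct {
            const Dims = dims;
            v: f64,
            fn eq(self: @This(), r: anytype) bool {
                // Reject mismatched dimensions while compiling, like
                // the guard in the diff above.
                if (comptime Dims != @TypeOf(r).Dims)
                    @compileError("Dimension mismatch in eq.");
                return self.v == r.v;
            }
        };
    }

    test "matching dims compare at runtime" {
        const a = Quantity(1){ .v = 2.0 };
        const b = Quantity(1){ .v = 2.0 };
        try std.testing.expect(a.eq(b));
        // a.eq(Quantity(2){ .v = 2.0 }) would fail to compile.
    }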

@@ -10,27 +10,27 @@ pub fn main(init: std.process.Init) !void {
io = init.io;
-try vectorSIMDvsNative(f64, &stdout_writer.interface);
-try stdout_writer.flush();
-try vectorSIMDvsNative(f32, &stdout_writer.interface);
-try stdout_writer.flush();
-try vectorSIMDvsNative(i32, &stdout_writer.interface);
-try stdout_writer.flush();
-try vectorSIMDvsNative(i64, &stdout_writer.interface);
-try stdout_writer.flush();
-try vectorSIMDvsNative(i128, &stdout_writer.interface);
-try stdout_writer.flush();
-try bench_Scalar(&stdout_writer.interface);
-try stdout_writer.flush();
+// try vectorSIMDvsNative(f64, &stdout_writer.interface);
+// try stdout_writer.flush();
+// try vectorSIMDvsNative(f32, &stdout_writer.interface);
+// try stdout_writer.flush();
+// try vectorSIMDvsNative(i32, &stdout_writer.interface);
+// try stdout_writer.flush();
+// try vectorSIMDvsNative(i64, &stdout_writer.interface);
+// try stdout_writer.flush();
+// try vectorSIMDvsNative(i128, &stdout_writer.interface);
+// try stdout_writer.flush();
+//
+// try bench_Scalar(&stdout_writer.interface);
+// try stdout_writer.flush();
try bench_vsNative(&stdout_writer.interface);
try stdout_writer.flush();
-try bench_crossTypeVsNative(&stdout_writer.interface);
-try stdout_writer.flush();
-try bench_Vector(&stdout_writer.interface);
-try stdout_writer.flush();
-try bench_HighDimTensor(&stdout_writer.interface);
-try stdout_writer.flush();
+// try bench_crossTypeVsNative(&stdout_writer.interface);
+// try stdout_writer.flush();
+// try bench_Vector(&stdout_writer.interface);
+// try stdout_writer.flush();
+// try bench_HighDimTensor(&stdout_writer.interface);
+// try stdout_writer.flush();
}
fn getTime() Io.Timestamp {
@@ -200,24 +200,22 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
var native_total_ns: f64 = 0;
var quantity_total_ns: f64 = 0;
-const M = Tensor(T, .{ .L = 1 }, .{}, &.{1});
-const S = Tensor(T, .{ .T = 1 }, .{}, &.{1});
+const M = Tensor(T, .{}, .{}, &.{1});
std.mem.doNotOptimizeAway({
for (0..SAMPLES) |_| {
// --- 1. Benchmark Native ---
const n_start = getTime();
-for (0..ITERS) |i| {
-const a = getValT(T, i);
+const a = getValT(T, 10);
const b = getValT(T, 2);
+for (0..ITERS) |_| {
// Native logic branch
_ = if (comptime std.mem.eql(u8, op_name, "add"))
-a + b
+if (comptime @typeInfo(T) == .int) a +| b else a + b
else if (comptime std.mem.eql(u8, op_name, "sub"))
-a - b
+if (comptime @typeInfo(T) == .int) a -| b else a - b
else if (comptime std.mem.eql(u8, op_name, "mul"))
-a * b
+if (comptime @typeInfo(T) == .int) a *| b else a * b
else if (comptime std.mem.eql(u8, op_name, "div"))
if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b
else if (comptime std.mem.eql(u8, op_name, "abs"))
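The `+|`, `-|`, `*|` forms above are Zig's saturating operators: instead of integer overflow being checked illegal behaviour, results clamp to the type's bounds. That keeps the native baseline semantically comparable to the Tensor path, which (per the commented-out fast path earlier in the diff) also saturates integer data. A quick illustration:

    const std = @import("std");

    test "saturating operators clamp instead of overflowing" {
        try std.testing.expectEqual(@as(u8, 255), @as(u8, 250) +| 10);
        try std.testing.expectEqual(@as(u8, 0), @as(u8, 5) -| 10);
        try std.testing.expectEqual(@as(i8, 127), @as(i8, 64) *| 2);
    }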
@@ -234,10 +232,9 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
// --- 2. Benchmark Scalar ---
const q_start = getTime();
-for (0..ITERS) |i| {
-const qa = M.splat(getValT(T, i));
-const qb = if (comptime std.mem.eql(u8, op_name, "div")) S.splat(getValT(T, 2)) else M.splat(getValT(T, 2));
+const qa = M.splat(getValT(T, 10));
+const qb = M.splat(getValT(T, 2));
+for (0..ITERS) |_| {
// Scalar logic branch
_ = if (comptime std.mem.eql(u8, op_name, "add"))
qa.add(qb)
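The recurring change in both benchmark loops: operand construction (`getValT`, `splat`) is hoisted out of the timed loop and no longer depends on the loop index, so the timed region measures only the operation itself. A generic sketch of that shape, with `std.mem.doNotOptimizeAway` keeping ReleaseFast from deleting the otherwise-unused result:

    const std = @import("std");

    fn benchAdd(comptime ITERS: usize) void {
        // Operands are built once, outside the timed region.
        const a: u64 = 10;
        const b: u64 = 2;
        var acc: u64 = 0;
        for (0..ITERS) |_| {
            acc +%= a + b; // only the operation under test runs here
        }
        // Prevent the optimizer from discarding the loop entirely.
        std.mem.doNotOptimizeAway(acc);
    }

    test "hoisted operands" {
        benchAdd(1000);
    }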