From a0961e757174f7307ce7a498ca5aea70ebdd5d9f Mon Sep 17 00:00:00 2001 From: adrien Date: Mon, 27 Apr 2026 21:24:03 +0200 Subject: [PATCH] Start to optimize the shit out of it, still a long way to go After that GPPPPPUUUUUU baby! --- build.zig | 2 +- src/Base.zig | 10 ++++---- src/Tensor.zig | 45 +++++++++++++++++----------------- src/benchmark.zig | 61 ++++++++++++++++++++++------------------------- 4 files changed, 57 insertions(+), 61 deletions(-) diff --git a/build.zig b/build.zig index e507138..2883cb9 100644 --- a/build.zig +++ b/build.zig @@ -2,7 +2,7 @@ const std = @import("std"); pub fn build(b: *std.Build) void { const target = b.standardTargetOptions(.{}); - const optimize = b.standardOptimizeOption(.{}); + const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseFast }); // 1. Define the module so other projects can import it _ = b.addModule("dimal", .{ diff --git a/src/Base.zig b/src/Base.zig index 4a475a7..ebc4780 100644 --- a/src/Base.zig +++ b/src/Base.zig @@ -186,12 +186,12 @@ test "BaseQuantities - Kinematics equations" { // Velocity = Distance / Time const v = d.div(t); try std.testing.expectEqual(25.0, v.data[0]); - try std.testing.expect(Speed.dims.eql(@TypeOf(v).dims)); + try comptime std.testing.expect(Speed.dims.eql(@TypeOf(v).dims)); // Acceleration = Velocity / Time const a = v.div(t); try std.testing.expectEqual(12.5, a.data[0]); - try std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims)); + try comptime std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims)); } test "BaseQuantities - Dynamics (Force and Work)" { @@ -203,13 +203,13 @@ test "BaseQuantities - Dynamics (Force and Work)" { // Force = mass * acceleration const f = m.mul(a); try std.testing.expectEqual(98, f.data[0]); - try std.testing.expect(Force.dims.eql(@TypeOf(f).dims)); + try comptime std.testing.expect(Force.dims.eql(@TypeOf(f).dims)); // Energy (Work) = Force * distance const distance = Meter.Of(f32).splat(5.0); const energy = f.mul(distance); try std.testing.expectEqual(490, energy.data[0]); - try std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims)); + try comptime std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims)); } test "BaseQuantities - Electric combinations" { @@ -219,7 +219,7 @@ test "BaseQuantities - Electric combinations" { // Charge = Current * time const charge = current.mul(time); try std.testing.expectEqual(6.0, charge.data[0]); - try std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims)); + try comptime std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims)); } test "Constants - Initialization and dimension checks" { diff --git a/src/Tensor.zig b/src/Tensor.zig index 1cc3a91..563a530 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -8,14 +8,14 @@ const Dimension = Dimensions.Dimension; // Comptime utilities // ───────────────────────────────────────────────────────────────────────────── -pub fn shapeTotal(comptime shape: []const comptime_int) usize { +pub fn shapeTotal(shape: []const comptime_int) usize { var t: comptime_int = 1; for (shape) |s| t *= s; return t; } /// Check if two shapes are strictly identical. -pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_int) bool { +pub fn shapeEql(a: []const comptime_int, b: []const comptime_int) bool { if (a.len != b.len) return false; for (a, 0..) |v, i| if (v != b[i]) return false; @@ -25,7 +25,7 @@ pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_i /// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). /// e.g. shape {3, 4} → strides {4, 1} /// shape {2, 3, 4} → strides {12, 4, 1} -pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_int { +pub fn shapeStrides(shape: []const comptime_int) [shape.len]comptime_int { var st: [shape.len]comptime_int = undefined; if (shape.len == 0) return st; st[shape.len - 1] = 1; @@ -37,7 +37,7 @@ pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_in } /// Return a copy of `shape` with the element at `axis` removed. -pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comptime_int) [shape.len - 1]comptime_int { +pub fn shapeRemoveAxis(shape: []const comptime_int, axis: comptime_int) [shape.len - 1]comptime_int { var out: [shape.len - 1]comptime_int = undefined; var j: comptime_int = 0; for (shape, 0..) |v, i| { @@ -50,7 +50,7 @@ pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comp } /// Concatenate two compile-time slices. -pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_int) [a.len + b.len]comptime_int { +pub fn shapeCat(a: []const comptime_int, b: []const comptime_int) [a.len + b.len]comptime_int { var out: [a.len + b.len]comptime_int = undefined; for (a, 0..) |v, i| out[i] = v; for (b, 0..) |v, i| out[a.len + i] = v; @@ -59,11 +59,7 @@ pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_i /// Decode a flat row-major index into N-D coordinates. /// Called only in comptime contexts (all arguments are comptime). -pub fn decodeFlatCoords( - comptime flat: comptime_int, - comptime n: comptime_int, - comptime strd: [n]comptime_int, -) [n]usize { +pub fn decodeFlatCoords(flat: comptime_int, n: comptime_int, strd: [n]comptime_int) [n]usize { var coords: [n]comptime_int = undefined; var tmp = flat; for (0..n) |i| { @@ -75,11 +71,7 @@ pub fn decodeFlatCoords( /// Encode N-D coordinates into a flat row-major index. /// Called only in comptime contexts. -pub fn encodeFlatCoords( - comptime coords: []const usize, - comptime n: usize, - comptime strd: [n]usize, -) usize { +pub fn encodeFlatCoords(coords: []const usize, n: usize, strd: [n]usize) usize { var flat: usize = 0; for (0..n) |i| flat += coords[i] * strd[i]; return flat; @@ -106,7 +98,7 @@ pub fn insertAxis( return out; } -fn isInt(comptime T: type) bool { +inline fn isInt(comptime T: type) bool { return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int; } @@ -130,25 +122,27 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales { else scale1); } - comptime return out; + return out; } // ───────────────────────────────────────────────────────────────────────────── // File-scope RHS normalisation helpers // ───────────────────────────────────────────────────────────────────────────── -fn isTensor(comptime Rhs: type) bool { +inline fn isTensor(comptime Rhs: type) bool { return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR"); } -fn RhsTensorType(comptime T: type, comptime Rhs: type) type { +inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type { if (comptime isTensor(Rhs)) return Rhs; return Tensor(T, .{}, .{}, &.{1}); } -fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) { +/// Take the anyvalue coming from operation and if it is a Tensor, return it. +/// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r). +inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) { const Rhs = @TypeOf(r); if (comptime isTensor(Rhs)) return r; - const scalar: T = switch (comptime @typeInfo(Rhs)) { + const scalar: T = switch (@typeInfo(Rhs)) { .comptime_int => switch (comptime @typeInfo(T)) { .float => @as(T, @floatFromInt(r)), else => @as(T, r), @@ -278,11 +272,13 @@ pub fn Tensor( shape_, ) { const rhs_q = rhs(r); - const RhsType = @TypeOf(rhs_q); + const RhsType = @TypeOf(r); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS."); + // if (comptime total == 1 and RhsType.scales.eql(scales)) + // return .{ .data = if (comptime isInt(T)) self.data +| r.data else self.data + r.data }; const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; @@ -450,6 +446,9 @@ pub fn Tensor( const DestT = ActualDest.ValueType; const DestVec = @Vector(total, DestT); + if (comptime ratio == 1.0 and T == DestT) + return .{ .data = self.data }; + // If ratio is 1, handle type conversion correctly based on BOTH source and dest types if (comptime ratio == 1.0) { const T_info = @typeInfo(T); @@ -533,7 +532,7 @@ pub fn Tensor( pub inline fn eq(self: Self, r: anytype) CmpResult { const rhs_q = rhs(r); if (comptime !dims.eql(@TypeOf(rhs_q).dims)) - @compileError("Dimension mismatch in eq: " ++ dims.str() ++ " vs " ++ @TypeOf(rhs_q).dims.str()); + @compileError("Dimension mismatch in ne."); const p = resolveScalePair(self, rhs_q); return cmpResult(p.l == p.r); } diff --git a/src/benchmark.zig b/src/benchmark.zig index 6371c25..dc8bcc0 100644 --- a/src/benchmark.zig +++ b/src/benchmark.zig @@ -10,27 +10,27 @@ pub fn main(init: std.process.Init) !void { io = init.io; - try vectorSIMDvsNative(f64, &stdout_writer.interface); - try stdout_writer.flush(); - try vectorSIMDvsNative(f32, &stdout_writer.interface); - try stdout_writer.flush(); - try vectorSIMDvsNative(i32, &stdout_writer.interface); - try stdout_writer.flush(); - try vectorSIMDvsNative(i64, &stdout_writer.interface); - try stdout_writer.flush(); - try vectorSIMDvsNative(i128, &stdout_writer.interface); - try stdout_writer.flush(); - - try bench_Scalar(&stdout_writer.interface); - try stdout_writer.flush(); + // try vectorSIMDvsNative(f64, &stdout_writer.interface); + // try stdout_writer.flush(); + // try vectorSIMDvsNative(f32, &stdout_writer.interface); + // try stdout_writer.flush(); + // try vectorSIMDvsNative(i32, &stdout_writer.interface); + // try stdout_writer.flush(); + // try vectorSIMDvsNative(i64, &stdout_writer.interface); + // try stdout_writer.flush(); + // try vectorSIMDvsNative(i128, &stdout_writer.interface); + // try stdout_writer.flush(); + // + // try bench_Scalar(&stdout_writer.interface); + // try stdout_writer.flush(); try bench_vsNative(&stdout_writer.interface); try stdout_writer.flush(); - try bench_crossTypeVsNative(&stdout_writer.interface); - try stdout_writer.flush(); - try bench_Vector(&stdout_writer.interface); - try stdout_writer.flush(); - try bench_HighDimTensor(&stdout_writer.interface); - try stdout_writer.flush(); + // try bench_crossTypeVsNative(&stdout_writer.interface); + // try stdout_writer.flush(); + // try bench_Vector(&stdout_writer.interface); + // try stdout_writer.flush(); + // try bench_HighDimTensor(&stdout_writer.interface); + // try stdout_writer.flush(); } fn getTime() Io.Timestamp { @@ -200,24 +200,22 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { var native_total_ns: f64 = 0; var quantity_total_ns: f64 = 0; - const M = Tensor(T, .{ .L = 1 }, .{}, &.{1}); - const S = Tensor(T, .{ .T = 1 }, .{}, &.{1}); + const M = Tensor(T, .{}, .{}, &.{1}); std.mem.doNotOptimizeAway({ for (0..SAMPLES) |_| { // --- 1. Benchmark Native --- const n_start = getTime(); - for (0..ITERS) |i| { - const a = getValT(T, i); - const b = getValT(T, 2); - + const a = getValT(T, 10); + const b = getValT(T, 2); + for (0..ITERS) |_| { // Native logic branch _ = if (comptime std.mem.eql(u8, op_name, "add")) - a + b + if (comptime @typeInfo(T) == .int) a +| b else a + b else if (comptime std.mem.eql(u8, op_name, "sub")) - a - b + if (comptime @typeInfo(T) == .int) a -| b else a - b else if (comptime std.mem.eql(u8, op_name, "mul")) - a * b + if (comptime @typeInfo(T) == .int) a *| b else a * b else if (comptime std.mem.eql(u8, op_name, "div")) if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b else if (comptime std.mem.eql(u8, op_name, "abs")) @@ -234,10 +232,9 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { // --- 2. Benchmark Scalar --- const q_start = getTime(); - for (0..ITERS) |i| { - const qa = M.splat(getValT(T, i)); - const qb = if (comptime std.mem.eql(u8, op_name, "div")) S.splat(getValT(T, 2)) else M.splat(getValT(T, 2)); - + const qa = M.splat(getValT(T, 10)); + const qb = M.splat(getValT(T, 2)); + for (0..ITERS) |_| { // Scalar logic branch _ = if (comptime std.mem.eql(u8, op_name, "add")) qa.add(qb)