Start to optimize the shit out of it, still a long way to go
After that GPPPPPUUUUUU baby!
This commit is contained in:
parent
168312b78e
commit
a0961e7571
@ -2,7 +2,7 @@ const std = @import("std");
|
||||
|
||||
pub fn build(b: *std.Build) void {
|
||||
const target = b.standardTargetOptions(.{});
|
||||
const optimize = b.standardOptimizeOption(.{});
|
||||
const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseFast });
|
||||
|
||||
// 1. Define the module so other projects can import it
|
||||
_ = b.addModule("dimal", .{
|
||||
|
||||
10
src/Base.zig
10
src/Base.zig
@ -186,12 +186,12 @@ test "BaseQuantities - Kinematics equations" {
|
||||
// Velocity = Distance / Time
|
||||
const v = d.div(t);
|
||||
try std.testing.expectEqual(25.0, v.data[0]);
|
||||
try std.testing.expect(Speed.dims.eql(@TypeOf(v).dims));
|
||||
try comptime std.testing.expect(Speed.dims.eql(@TypeOf(v).dims));
|
||||
|
||||
// Acceleration = Velocity / Time
|
||||
const a = v.div(t);
|
||||
try std.testing.expectEqual(12.5, a.data[0]);
|
||||
try std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims));
|
||||
try comptime std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims));
|
||||
}
|
||||
|
||||
test "BaseQuantities - Dynamics (Force and Work)" {
|
||||
@ -203,13 +203,13 @@ test "BaseQuantities - Dynamics (Force and Work)" {
|
||||
// Force = mass * acceleration
|
||||
const f = m.mul(a);
|
||||
try std.testing.expectEqual(98, f.data[0]);
|
||||
try std.testing.expect(Force.dims.eql(@TypeOf(f).dims));
|
||||
try comptime std.testing.expect(Force.dims.eql(@TypeOf(f).dims));
|
||||
|
||||
// Energy (Work) = Force * distance
|
||||
const distance = Meter.Of(f32).splat(5.0);
|
||||
const energy = f.mul(distance);
|
||||
try std.testing.expectEqual(490, energy.data[0]);
|
||||
try std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims));
|
||||
try comptime std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims));
|
||||
}
|
||||
|
||||
test "BaseQuantities - Electric combinations" {
|
||||
@ -219,7 +219,7 @@ test "BaseQuantities - Electric combinations" {
|
||||
// Charge = Current * time
|
||||
const charge = current.mul(time);
|
||||
try std.testing.expectEqual(6.0, charge.data[0]);
|
||||
try std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims));
|
||||
try comptime std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims));
|
||||
}
|
||||
|
||||
test "Constants - Initialization and dimension checks" {
|
||||
|
||||
@ -8,14 +8,14 @@ const Dimension = Dimensions.Dimension;
|
||||
// Comptime utilities
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
pub fn shapeTotal(comptime shape: []const comptime_int) usize {
|
||||
pub fn shapeTotal(shape: []const comptime_int) usize {
|
||||
var t: comptime_int = 1;
|
||||
for (shape) |s| t *= s;
|
||||
return t;
|
||||
}
|
||||
|
||||
/// Check if two shapes are strictly identical.
|
||||
pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_int) bool {
|
||||
pub fn shapeEql(a: []const comptime_int, b: []const comptime_int) bool {
|
||||
if (a.len != b.len) return false;
|
||||
for (a, 0..) |v, i|
|
||||
if (v != b[i]) return false;
|
||||
@ -25,7 +25,7 @@ pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_i
|
||||
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
|
||||
/// e.g. shape {3, 4} → strides {4, 1}
|
||||
/// shape {2, 3, 4} → strides {12, 4, 1}
|
||||
pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_int {
|
||||
pub fn shapeStrides(shape: []const comptime_int) [shape.len]comptime_int {
|
||||
var st: [shape.len]comptime_int = undefined;
|
||||
if (shape.len == 0) return st;
|
||||
st[shape.len - 1] = 1;
|
||||
@ -37,7 +37,7 @@ pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_in
|
||||
}
|
||||
|
||||
/// Return a copy of `shape` with the element at `axis` removed.
|
||||
pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comptime_int) [shape.len - 1]comptime_int {
|
||||
pub fn shapeRemoveAxis(shape: []const comptime_int, axis: comptime_int) [shape.len - 1]comptime_int {
|
||||
var out: [shape.len - 1]comptime_int = undefined;
|
||||
var j: comptime_int = 0;
|
||||
for (shape, 0..) |v, i| {
|
||||
@ -50,7 +50,7 @@ pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comp
|
||||
}
|
||||
|
||||
/// Concatenate two compile-time slices.
|
||||
pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_int) [a.len + b.len]comptime_int {
|
||||
pub fn shapeCat(a: []const comptime_int, b: []const comptime_int) [a.len + b.len]comptime_int {
|
||||
var out: [a.len + b.len]comptime_int = undefined;
|
||||
for (a, 0..) |v, i| out[i] = v;
|
||||
for (b, 0..) |v, i| out[a.len + i] = v;
|
||||
@ -59,11 +59,7 @@ pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_i
|
||||
|
||||
/// Decode a flat row-major index into N-D coordinates.
|
||||
/// Called only in comptime contexts (all arguments are comptime).
|
||||
pub fn decodeFlatCoords(
|
||||
comptime flat: comptime_int,
|
||||
comptime n: comptime_int,
|
||||
comptime strd: [n]comptime_int,
|
||||
) [n]usize {
|
||||
pub fn decodeFlatCoords(flat: comptime_int, n: comptime_int, strd: [n]comptime_int) [n]usize {
|
||||
var coords: [n]comptime_int = undefined;
|
||||
var tmp = flat;
|
||||
for (0..n) |i| {
|
||||
@ -75,11 +71,7 @@ pub fn decodeFlatCoords(
|
||||
|
||||
/// Encode N-D coordinates into a flat row-major index.
|
||||
/// Called only in comptime contexts.
|
||||
pub fn encodeFlatCoords(
|
||||
comptime coords: []const usize,
|
||||
comptime n: usize,
|
||||
comptime strd: [n]usize,
|
||||
) usize {
|
||||
pub fn encodeFlatCoords(coords: []const usize, n: usize, strd: [n]usize) usize {
|
||||
var flat: usize = 0;
|
||||
for (0..n) |i| flat += coords[i] * strd[i];
|
||||
return flat;
|
||||
@ -106,7 +98,7 @@ pub fn insertAxis(
|
||||
return out;
|
||||
}
|
||||
|
||||
fn isInt(comptime T: type) bool {
|
||||
inline fn isInt(comptime T: type) bool {
|
||||
return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int;
|
||||
}
|
||||
|
||||
@ -130,25 +122,27 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales {
|
||||
else
|
||||
scale1);
|
||||
}
|
||||
comptime return out;
|
||||
return out;
|
||||
}
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// File-scope RHS normalisation helpers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
fn isTensor(comptime Rhs: type) bool {
|
||||
inline fn isTensor(comptime Rhs: type) bool {
|
||||
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
|
||||
}
|
||||
|
||||
fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
|
||||
inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
|
||||
if (comptime isTensor(Rhs)) return Rhs;
|
||||
return Tensor(T, .{}, .{}, &.{1});
|
||||
}
|
||||
|
||||
fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) {
|
||||
/// Take the anyvalue coming from operation and if it is a Tensor, return it.
|
||||
/// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r).
|
||||
inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) {
|
||||
const Rhs = @TypeOf(r);
|
||||
if (comptime isTensor(Rhs)) return r;
|
||||
const scalar: T = switch (comptime @typeInfo(Rhs)) {
|
||||
const scalar: T = switch (@typeInfo(Rhs)) {
|
||||
.comptime_int => switch (comptime @typeInfo(T)) {
|
||||
.float => @as(T, @floatFromInt(r)),
|
||||
else => @as(T, r),
|
||||
@ -278,11 +272,13 @@ pub fn Tensor(
|
||||
shape_,
|
||||
) {
|
||||
const rhs_q = rhs(r);
|
||||
const RhsType = @TypeOf(rhs_q);
|
||||
const RhsType = @TypeOf(r);
|
||||
if (comptime !dims.eql(RhsType.dims))
|
||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
||||
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
|
||||
// if (comptime total == 1 and RhsType.scales.eql(scales))
|
||||
// return .{ .data = if (comptime isInt(T)) self.data +| r.data else self.data + r.data };
|
||||
|
||||
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||
@ -450,6 +446,9 @@ pub fn Tensor(
|
||||
const DestT = ActualDest.ValueType;
|
||||
const DestVec = @Vector(total, DestT);
|
||||
|
||||
if (comptime ratio == 1.0 and T == DestT)
|
||||
return .{ .data = self.data };
|
||||
|
||||
// If ratio is 1, handle type conversion correctly based on BOTH source and dest types
|
||||
if (comptime ratio == 1.0) {
|
||||
const T_info = @typeInfo(T);
|
||||
@ -533,7 +532,7 @@ pub fn Tensor(
|
||||
pub inline fn eq(self: Self, r: anytype) CmpResult {
|
||||
const rhs_q = rhs(r);
|
||||
if (comptime !dims.eql(@TypeOf(rhs_q).dims))
|
||||
@compileError("Dimension mismatch in eq: " ++ dims.str() ++ " vs " ++ @TypeOf(rhs_q).dims.str());
|
||||
@compileError("Dimension mismatch in ne.");
|
||||
const p = resolveScalePair(self, rhs_q);
|
||||
return cmpResult(p.l == p.r);
|
||||
}
|
||||
|
||||
@ -10,27 +10,27 @@ pub fn main(init: std.process.Init) !void {
|
||||
|
||||
io = init.io;
|
||||
|
||||
try vectorSIMDvsNative(f64, &stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try vectorSIMDvsNative(f32, &stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try vectorSIMDvsNative(i32, &stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try vectorSIMDvsNative(i64, &stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try vectorSIMDvsNative(i128, &stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
|
||||
try bench_Scalar(&stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
// try vectorSIMDvsNative(f64, &stdout_writer.interface);
|
||||
// try stdout_writer.flush();
|
||||
// try vectorSIMDvsNative(f32, &stdout_writer.interface);
|
||||
// try stdout_writer.flush();
|
||||
// try vectorSIMDvsNative(i32, &stdout_writer.interface);
|
||||
// try stdout_writer.flush();
|
||||
// try vectorSIMDvsNative(i64, &stdout_writer.interface);
|
||||
// try stdout_writer.flush();
|
||||
// try vectorSIMDvsNative(i128, &stdout_writer.interface);
|
||||
// try stdout_writer.flush();
|
||||
//
|
||||
// try bench_Scalar(&stdout_writer.interface);
|
||||
// try stdout_writer.flush();
|
||||
try bench_vsNative(&stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try bench_crossTypeVsNative(&stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try bench_Vector(&stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
try bench_HighDimTensor(&stdout_writer.interface);
|
||||
try stdout_writer.flush();
|
||||
// try bench_crossTypeVsNative(&stdout_writer.interface);
|
||||
// try stdout_writer.flush();
|
||||
// try bench_Vector(&stdout_writer.interface);
|
||||
// try stdout_writer.flush();
|
||||
// try bench_HighDimTensor(&stdout_writer.interface);
|
||||
// try stdout_writer.flush();
|
||||
}
|
||||
|
||||
fn getTime() Io.Timestamp {
|
||||
@ -200,24 +200,22 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
|
||||
var native_total_ns: f64 = 0;
|
||||
var quantity_total_ns: f64 = 0;
|
||||
|
||||
const M = Tensor(T, .{ .L = 1 }, .{}, &.{1});
|
||||
const S = Tensor(T, .{ .T = 1 }, .{}, &.{1});
|
||||
const M = Tensor(T, .{}, .{}, &.{1});
|
||||
|
||||
std.mem.doNotOptimizeAway({
|
||||
for (0..SAMPLES) |_| {
|
||||
// --- 1. Benchmark Native ---
|
||||
const n_start = getTime();
|
||||
for (0..ITERS) |i| {
|
||||
const a = getValT(T, i);
|
||||
const a = getValT(T, 10);
|
||||
const b = getValT(T, 2);
|
||||
|
||||
for (0..ITERS) |_| {
|
||||
// Native logic branch
|
||||
_ = if (comptime std.mem.eql(u8, op_name, "add"))
|
||||
a + b
|
||||
if (comptime @typeInfo(T) == .int) a +| b else a + b
|
||||
else if (comptime std.mem.eql(u8, op_name, "sub"))
|
||||
a - b
|
||||
if (comptime @typeInfo(T) == .int) a -| b else a - b
|
||||
else if (comptime std.mem.eql(u8, op_name, "mul"))
|
||||
a * b
|
||||
if (comptime @typeInfo(T) == .int) a *| b else a * b
|
||||
else if (comptime std.mem.eql(u8, op_name, "div"))
|
||||
if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b
|
||||
else if (comptime std.mem.eql(u8, op_name, "abs"))
|
||||
@ -234,10 +232,9 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
|
||||
|
||||
// --- 2. Benchmark Scalar ---
|
||||
const q_start = getTime();
|
||||
for (0..ITERS) |i| {
|
||||
const qa = M.splat(getValT(T, i));
|
||||
const qb = if (comptime std.mem.eql(u8, op_name, "div")) S.splat(getValT(T, 2)) else M.splat(getValT(T, 2));
|
||||
|
||||
const qa = M.splat(getValT(T, 10));
|
||||
const qb = M.splat(getValT(T, 2));
|
||||
for (0..ITERS) |_| {
|
||||
// Scalar logic branch
|
||||
_ = if (comptime std.mem.eql(u8, op_name, "add"))
|
||||
qa.add(qb)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user