Start aggressively optimizing the library; there is still a long way to go.
After that, the next step is GPU support.
This commit is contained in:
parent
168312b78e
commit
a0961e7571
@ -2,7 +2,7 @@ const std = @import("std");
|
|||||||
|
|
||||||
pub fn build(b: *std.Build) void {
|
pub fn build(b: *std.Build) void {
|
||||||
const target = b.standardTargetOptions(.{});
|
const target = b.standardTargetOptions(.{});
|
||||||
const optimize = b.standardOptimizeOption(.{});
|
const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseFast });
|
||||||
|
|
||||||
// 1. Define the module so other projects can import it
|
// 1. Define the module so other projects can import it
|
||||||
_ = b.addModule("dimal", .{
|
_ = b.addModule("dimal", .{
|
||||||
|
|||||||
10
src/Base.zig
10
src/Base.zig
@ -186,12 +186,12 @@ test "BaseQuantities - Kinematics equations" {
|
|||||||
// Velocity = Distance / Time
|
// Velocity = Distance / Time
|
||||||
const v = d.div(t);
|
const v = d.div(t);
|
||||||
try std.testing.expectEqual(25.0, v.data[0]);
|
try std.testing.expectEqual(25.0, v.data[0]);
|
||||||
try std.testing.expect(Speed.dims.eql(@TypeOf(v).dims));
|
try comptime std.testing.expect(Speed.dims.eql(@TypeOf(v).dims));
|
||||||
|
|
||||||
// Acceleration = Velocity / Time
|
// Acceleration = Velocity / Time
|
||||||
const a = v.div(t);
|
const a = v.div(t);
|
||||||
try std.testing.expectEqual(12.5, a.data[0]);
|
try std.testing.expectEqual(12.5, a.data[0]);
|
||||||
try std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims));
|
try comptime std.testing.expect(Acceleration.dims.eql(@TypeOf(a).dims));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "BaseQuantities - Dynamics (Force and Work)" {
|
test "BaseQuantities - Dynamics (Force and Work)" {
|
||||||
@ -203,13 +203,13 @@ test "BaseQuantities - Dynamics (Force and Work)" {
|
|||||||
// Force = mass * acceleration
|
// Force = mass * acceleration
|
||||||
const f = m.mul(a);
|
const f = m.mul(a);
|
||||||
try std.testing.expectEqual(98, f.data[0]);
|
try std.testing.expectEqual(98, f.data[0]);
|
||||||
try std.testing.expect(Force.dims.eql(@TypeOf(f).dims));
|
try comptime std.testing.expect(Force.dims.eql(@TypeOf(f).dims));
|
||||||
|
|
||||||
// Energy (Work) = Force * distance
|
// Energy (Work) = Force * distance
|
||||||
const distance = Meter.Of(f32).splat(5.0);
|
const distance = Meter.Of(f32).splat(5.0);
|
||||||
const energy = f.mul(distance);
|
const energy = f.mul(distance);
|
||||||
try std.testing.expectEqual(490, energy.data[0]);
|
try std.testing.expectEqual(490, energy.data[0]);
|
||||||
try std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims));
|
try comptime std.testing.expect(Energy.dims.eql(@TypeOf(energy).dims));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "BaseQuantities - Electric combinations" {
|
test "BaseQuantities - Electric combinations" {
|
||||||
@ -219,7 +219,7 @@ test "BaseQuantities - Electric combinations" {
|
|||||||
// Charge = Current * time
|
// Charge = Current * time
|
||||||
const charge = current.mul(time);
|
const charge = current.mul(time);
|
||||||
try std.testing.expectEqual(6.0, charge.data[0]);
|
try std.testing.expectEqual(6.0, charge.data[0]);
|
||||||
try std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims));
|
try comptime std.testing.expect(ElectricCharge.dims.eql(@TypeOf(charge).dims));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "Constants - Initialization and dimension checks" {
|
test "Constants - Initialization and dimension checks" {
|
||||||
|
|||||||
@ -8,14 +8,14 @@ const Dimension = Dimensions.Dimension;
|
|||||||
// Comptime utilities
|
// Comptime utilities
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
pub fn shapeTotal(comptime shape: []const comptime_int) usize {
|
pub fn shapeTotal(shape: []const comptime_int) usize {
|
||||||
var t: comptime_int = 1;
|
var t: comptime_int = 1;
|
||||||
for (shape) |s| t *= s;
|
for (shape) |s| t *= s;
|
||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if two shapes are strictly identical.
|
/// Check if two shapes are strictly identical.
|
||||||
pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_int) bool {
|
pub fn shapeEql(a: []const comptime_int, b: []const comptime_int) bool {
|
||||||
if (a.len != b.len) return false;
|
if (a.len != b.len) return false;
|
||||||
for (a, 0..) |v, i|
|
for (a, 0..) |v, i|
|
||||||
if (v != b[i]) return false;
|
if (v != b[i]) return false;
|
||||||
@ -25,7 +25,7 @@ pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_i
|
|||||||
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
|
/// Row-major (C-order) strides: strides[i] = product(shape[i+1..]).
|
||||||
/// e.g. shape {3, 4} → strides {4, 1}
|
/// e.g. shape {3, 4} → strides {4, 1}
|
||||||
/// shape {2, 3, 4} → strides {12, 4, 1}
|
/// shape {2, 3, 4} → strides {12, 4, 1}
|
||||||
pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_int {
|
pub fn shapeStrides(shape: []const comptime_int) [shape.len]comptime_int {
|
||||||
var st: [shape.len]comptime_int = undefined;
|
var st: [shape.len]comptime_int = undefined;
|
||||||
if (shape.len == 0) return st;
|
if (shape.len == 0) return st;
|
||||||
st[shape.len - 1] = 1;
|
st[shape.len - 1] = 1;
|
||||||
@ -37,7 +37,7 @@ pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_in
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Return a copy of `shape` with the element at `axis` removed.
|
/// Return a copy of `shape` with the element at `axis` removed.
|
||||||
pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comptime_int) [shape.len - 1]comptime_int {
|
pub fn shapeRemoveAxis(shape: []const comptime_int, axis: comptime_int) [shape.len - 1]comptime_int {
|
||||||
var out: [shape.len - 1]comptime_int = undefined;
|
var out: [shape.len - 1]comptime_int = undefined;
|
||||||
var j: comptime_int = 0;
|
var j: comptime_int = 0;
|
||||||
for (shape, 0..) |v, i| {
|
for (shape, 0..) |v, i| {
|
||||||
@ -50,7 +50,7 @@ pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comp
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Concatenate two compile-time slices.
|
/// Concatenate two compile-time slices.
|
||||||
pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_int) [a.len + b.len]comptime_int {
|
pub fn shapeCat(a: []const comptime_int, b: []const comptime_int) [a.len + b.len]comptime_int {
|
||||||
var out: [a.len + b.len]comptime_int = undefined;
|
var out: [a.len + b.len]comptime_int = undefined;
|
||||||
for (a, 0..) |v, i| out[i] = v;
|
for (a, 0..) |v, i| out[i] = v;
|
||||||
for (b, 0..) |v, i| out[a.len + i] = v;
|
for (b, 0..) |v, i| out[a.len + i] = v;
|
||||||
@ -59,11 +59,7 @@ pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_i
|
|||||||
|
|
||||||
/// Decode a flat row-major index into N-D coordinates.
|
/// Decode a flat row-major index into N-D coordinates.
|
||||||
/// Called only in comptime contexts (all arguments are comptime).
|
/// Called only in comptime contexts (all arguments are comptime).
|
||||||
pub fn decodeFlatCoords(
|
pub fn decodeFlatCoords(flat: comptime_int, n: comptime_int, strd: [n]comptime_int) [n]usize {
|
||||||
comptime flat: comptime_int,
|
|
||||||
comptime n: comptime_int,
|
|
||||||
comptime strd: [n]comptime_int,
|
|
||||||
) [n]usize {
|
|
||||||
var coords: [n]comptime_int = undefined;
|
var coords: [n]comptime_int = undefined;
|
||||||
var tmp = flat;
|
var tmp = flat;
|
||||||
for (0..n) |i| {
|
for (0..n) |i| {
|
||||||
@ -75,11 +71,7 @@ pub fn decodeFlatCoords(
|
|||||||
|
|
||||||
/// Encode N-D coordinates into a flat row-major index.
|
/// Encode N-D coordinates into a flat row-major index.
|
||||||
/// Called only in comptime contexts.
|
/// Called only in comptime contexts.
|
||||||
pub fn encodeFlatCoords(
|
pub fn encodeFlatCoords(coords: []const usize, n: usize, strd: [n]usize) usize {
|
||||||
comptime coords: []const usize,
|
|
||||||
comptime n: usize,
|
|
||||||
comptime strd: [n]usize,
|
|
||||||
) usize {
|
|
||||||
var flat: usize = 0;
|
var flat: usize = 0;
|
||||||
for (0..n) |i| flat += coords[i] * strd[i];
|
for (0..n) |i| flat += coords[i] * strd[i];
|
||||||
return flat;
|
return flat;
|
||||||
@ -106,7 +98,7 @@ pub fn insertAxis(
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn isInt(comptime T: type) bool {
|
inline fn isInt(comptime T: type) bool {
|
||||||
return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int;
|
return @typeInfo(T) == .int or @typeInfo(T) == .comptime_int;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -130,25 +122,27 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales {
|
|||||||
else
|
else
|
||||||
scale1);
|
scale1);
|
||||||
}
|
}
|
||||||
comptime return out;
|
return out;
|
||||||
}
|
}
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
// File-scope RHS normalisation helpers
|
// File-scope RHS normalisation helpers
|
||||||
// ─────────────────────────────────────────────────────────────────────────────
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
fn isTensor(comptime Rhs: type) bool {
|
inline fn isTensor(comptime Rhs: type) bool {
|
||||||
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
|
return comptime @typeInfo(Rhs) == .@"struct" and @hasDecl(Rhs, "ISTENSOR");
|
||||||
}
|
}
|
||||||
|
|
||||||
fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
|
inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type {
|
||||||
if (comptime isTensor(Rhs)) return Rhs;
|
if (comptime isTensor(Rhs)) return Rhs;
|
||||||
return Tensor(T, .{}, .{}, &.{1});
|
return Tensor(T, .{}, .{}, &.{1});
|
||||||
}
|
}
|
||||||
|
|
||||||
fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) {
|
/// Take the `anytype` operand passed to an operation; if it is already a Tensor, return it unchanged.
|
||||||
|
/// If it is a float or an int, return `Tensor(T, .{}, .{}, &.{1}).splat(r)` (with `r` converted to `T`).
|
||||||
|
inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) {
|
||||||
const Rhs = @TypeOf(r);
|
const Rhs = @TypeOf(r);
|
||||||
if (comptime isTensor(Rhs)) return r;
|
if (comptime isTensor(Rhs)) return r;
|
||||||
const scalar: T = switch (comptime @typeInfo(Rhs)) {
|
const scalar: T = switch (@typeInfo(Rhs)) {
|
||||||
.comptime_int => switch (comptime @typeInfo(T)) {
|
.comptime_int => switch (comptime @typeInfo(T)) {
|
||||||
.float => @as(T, @floatFromInt(r)),
|
.float => @as(T, @floatFromInt(r)),
|
||||||
else => @as(T, r),
|
else => @as(T, r),
|
||||||
@ -278,11 +272,13 @@ pub fn Tensor(
|
|||||||
shape_,
|
shape_,
|
||||||
) {
|
) {
|
||||||
const rhs_q = rhs(r);
|
const rhs_q = rhs(r);
|
||||||
const RhsType = @TypeOf(rhs_q);
|
const RhsType = @TypeOf(r);
|
||||||
if (comptime !dims.eql(RhsType.dims))
|
if (comptime !dims.eql(RhsType.dims))
|
||||||
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||||
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape))
|
||||||
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
|
// if (comptime total == 1 and RhsType.scales.eql(scales))
|
||||||
|
// return .{ .data = if (comptime isInt(T)) self.data +| r.data else self.data + r.data };
|
||||||
|
|
||||||
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_);
|
||||||
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data;
|
||||||
@ -450,6 +446,9 @@ pub fn Tensor(
|
|||||||
const DestT = ActualDest.ValueType;
|
const DestT = ActualDest.ValueType;
|
||||||
const DestVec = @Vector(total, DestT);
|
const DestVec = @Vector(total, DestT);
|
||||||
|
|
||||||
|
if (comptime ratio == 1.0 and T == DestT)
|
||||||
|
return .{ .data = self.data };
|
||||||
|
|
||||||
// If ratio is 1, handle type conversion correctly based on BOTH source and dest types
|
// If ratio is 1, handle type conversion correctly based on BOTH source and dest types
|
||||||
if (comptime ratio == 1.0) {
|
if (comptime ratio == 1.0) {
|
||||||
const T_info = @typeInfo(T);
|
const T_info = @typeInfo(T);
|
||||||
@ -533,7 +532,7 @@ pub fn Tensor(
|
|||||||
pub inline fn eq(self: Self, r: anytype) CmpResult {
|
pub inline fn eq(self: Self, r: anytype) CmpResult {
|
||||||
const rhs_q = rhs(r);
|
const rhs_q = rhs(r);
|
||||||
if (comptime !dims.eql(@TypeOf(rhs_q).dims))
|
if (comptime !dims.eql(@TypeOf(rhs_q).dims))
|
||||||
@compileError("Dimension mismatch in eq: " ++ dims.str() ++ " vs " ++ @TypeOf(rhs_q).dims.str());
|
@compileError("Dimension mismatch in ne.");
|
||||||
const p = resolveScalePair(self, rhs_q);
|
const p = resolveScalePair(self, rhs_q);
|
||||||
return cmpResult(p.l == p.r);
|
return cmpResult(p.l == p.r);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,27 +10,27 @@ pub fn main(init: std.process.Init) !void {
|
|||||||
|
|
||||||
io = init.io;
|
io = init.io;
|
||||||
|
|
||||||
try vectorSIMDvsNative(f64, &stdout_writer.interface);
|
// try vectorSIMDvsNative(f64, &stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
// try stdout_writer.flush();
|
||||||
try vectorSIMDvsNative(f32, &stdout_writer.interface);
|
// try vectorSIMDvsNative(f32, &stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
// try stdout_writer.flush();
|
||||||
try vectorSIMDvsNative(i32, &stdout_writer.interface);
|
// try vectorSIMDvsNative(i32, &stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
// try stdout_writer.flush();
|
||||||
try vectorSIMDvsNative(i64, &stdout_writer.interface);
|
// try vectorSIMDvsNative(i64, &stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
// try stdout_writer.flush();
|
||||||
try vectorSIMDvsNative(i128, &stdout_writer.interface);
|
// try vectorSIMDvsNative(i128, &stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
// try stdout_writer.flush();
|
||||||
|
//
|
||||||
try bench_Scalar(&stdout_writer.interface);
|
// try bench_Scalar(&stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
// try stdout_writer.flush();
|
||||||
try bench_vsNative(&stdout_writer.interface);
|
try bench_vsNative(&stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
try stdout_writer.flush();
|
||||||
try bench_crossTypeVsNative(&stdout_writer.interface);
|
// try bench_crossTypeVsNative(&stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
// try stdout_writer.flush();
|
||||||
try bench_Vector(&stdout_writer.interface);
|
// try bench_Vector(&stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
// try stdout_writer.flush();
|
||||||
try bench_HighDimTensor(&stdout_writer.interface);
|
// try bench_HighDimTensor(&stdout_writer.interface);
|
||||||
try stdout_writer.flush();
|
// try stdout_writer.flush();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn getTime() Io.Timestamp {
|
fn getTime() Io.Timestamp {
|
||||||
@ -200,24 +200,22 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
|
|||||||
var native_total_ns: f64 = 0;
|
var native_total_ns: f64 = 0;
|
||||||
var quantity_total_ns: f64 = 0;
|
var quantity_total_ns: f64 = 0;
|
||||||
|
|
||||||
const M = Tensor(T, .{ .L = 1 }, .{}, &.{1});
|
const M = Tensor(T, .{}, .{}, &.{1});
|
||||||
const S = Tensor(T, .{ .T = 1 }, .{}, &.{1});
|
|
||||||
|
|
||||||
std.mem.doNotOptimizeAway({
|
std.mem.doNotOptimizeAway({
|
||||||
for (0..SAMPLES) |_| {
|
for (0..SAMPLES) |_| {
|
||||||
// --- 1. Benchmark Native ---
|
// --- 1. Benchmark Native ---
|
||||||
const n_start = getTime();
|
const n_start = getTime();
|
||||||
for (0..ITERS) |i| {
|
const a = getValT(T, 10);
|
||||||
const a = getValT(T, i);
|
const b = getValT(T, 2);
|
||||||
const b = getValT(T, 2);
|
for (0..ITERS) |_| {
|
||||||
|
|
||||||
// Native logic branch
|
// Native logic branch
|
||||||
_ = if (comptime std.mem.eql(u8, op_name, "add"))
|
_ = if (comptime std.mem.eql(u8, op_name, "add"))
|
||||||
a + b
|
if (comptime @typeInfo(T) == .int) a +| b else a + b
|
||||||
else if (comptime std.mem.eql(u8, op_name, "sub"))
|
else if (comptime std.mem.eql(u8, op_name, "sub"))
|
||||||
a - b
|
if (comptime @typeInfo(T) == .int) a -| b else a - b
|
||||||
else if (comptime std.mem.eql(u8, op_name, "mul"))
|
else if (comptime std.mem.eql(u8, op_name, "mul"))
|
||||||
a * b
|
if (comptime @typeInfo(T) == .int) a *| b else a * b
|
||||||
else if (comptime std.mem.eql(u8, op_name, "div"))
|
else if (comptime std.mem.eql(u8, op_name, "div"))
|
||||||
if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b
|
if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b
|
||||||
else if (comptime std.mem.eql(u8, op_name, "abs"))
|
else if (comptime std.mem.eql(u8, op_name, "abs"))
|
||||||
@ -234,10 +232,9 @@ fn bench_vsNative(writer: *std.Io.Writer) !void {
|
|||||||
|
|
||||||
// --- 2. Benchmark Scalar ---
|
// --- 2. Benchmark Scalar ---
|
||||||
const q_start = getTime();
|
const q_start = getTime();
|
||||||
for (0..ITERS) |i| {
|
const qa = M.splat(getValT(T, 10));
|
||||||
const qa = M.splat(getValT(T, i));
|
const qb = M.splat(getValT(T, 2));
|
||||||
const qb = if (comptime std.mem.eql(u8, op_name, "div")) S.splat(getValT(T, 2)) else M.splat(getValT(T, 2));
|
for (0..ITERS) |_| {
|
||||||
|
|
||||||
// Scalar logic branch
|
// Scalar logic branch
|
||||||
_ = if (comptime std.mem.eql(u8, op_name, "add"))
|
_ = if (comptime std.mem.eql(u8, op_name, "add"))
|
||||||
qa.add(qb)
|
qa.add(qb)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user