diff --git a/src/Scales.zig b/src/Scales.zig index 54484b4..e7877e1 100644 --- a/src/Scales.zig +++ b/src/Scales.zig @@ -118,6 +118,12 @@ pub fn set(self: *Self, key: Dimension, val: UnitScale) void { self.data.set(key, val); } +pub fn eql(self: Self, other: Self) bool { + for (self.data.values, other.data.values) |l, r| + if (l != r) return false; + return true; +} + pub fn argsOpt(self: Self) ArgOpts { var args: ArgOpts = undefined; for (std.enums.values(Dimension)) |d| diff --git a/src/Tensor.zig b/src/Tensor.zig index 563a530..9b2a593 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -203,7 +203,7 @@ pub fn Tensor( if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1."); } } - @setEvalBranchQuota(10_000_000); + @setEvalBranchQuota(100_000); const _total: usize = comptime shapeTotal(shape_); const _strides = comptime shapeStrides(shape_); @@ -271,20 +271,20 @@ pub fn Tensor( finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { - const rhs_q = rhs(r); + const rhs_t = rhs(r); const RhsType = @TypeOf(r); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS."); - // if (comptime total == 1 and RhsType.scales.eql(scales)) - // return .{ .data = if (comptime isInt(T)) self.data +| r.data else self.data + r.data }; + if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too + return .{ .data = if (comptime isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data }; const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const rr: Vec = blk: { const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); - const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; return .{ .data = if (comptime isInt(T)) l +| rr else l + rr }; @@ -298,18 +298,20 @@ pub fn Tensor( finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { - const rhs_q = rhs(r); - const RhsType = @TypeOf(rhs_q); + const rhs_t = rhs(r); + const RhsType = @TypeOf(rhs_t); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS."); + if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too + return .{ .data = if (comptime isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data }; const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; const rr: Vec = blk: { const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); - const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); + const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; return .{ .data = if (comptime isInt(T)) l -| rr else l - rr }; diff --git a/src/benchmark.zig b/src/benchmark.zig index dc8bcc0..c6fb592 100644 --- a/src/benchmark.zig +++ b/src/benchmark.zig @@ -26,11 +26,11 @@ pub fn main(init: std.process.Init) !void { try bench_vsNative(&stdout_writer.interface); try stdout_writer.flush(); // try bench_crossTypeVsNative(&stdout_writer.interface); - // try stdout_writer.flush(); - // try bench_Vector(&stdout_writer.interface); - // try stdout_writer.flush(); - // try bench_HighDimTensor(&stdout_writer.interface); - // try stdout_writer.flush(); + try stdout_writer.flush(); + try bench_Vector(&stdout_writer.interface); + try stdout_writer.flush(); + try bench_HighDimTensor(&stdout_writer.interface); + try stdout_writer.flush(); } fn getTime() Io.Timestamp { @@ -171,7 +171,7 @@ fn bench_Scalar(writer: *std.Io.Writer) !void { fn bench_vsNative(writer: *std.Io.Writer) !void { const ITERS: usize = 100_000; - const SAMPLES: usize = 5; + const SAMPLES: usize = 100; const getValT = struct { fn f(comptime TT: type, i: usize) TT { @@ -180,8 +180,8 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { } }.f; - const Types = .{ i32, i64, i128, f32, f64 }; - const TNames = .{ "i32", "i64", "i128", "f32", "f64" }; + const Types = .{ f64, i64, i128, f32, f64 }; + const TNames = .{ "f64", "i64", "i128", "f32", "f64" }; // Expanded Ops to match bench_Scalar const Ops = .{ "add", "sub", "mul", "div", "abs", "eq", "gt" }; @@ -189,16 +189,17 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { \\ \\ Scalar vs Native Overhead Analysis \\ - \\┌───────────┬──────┬───────────┬───────────┬───────────┐ - \\│ Operation │ Type │ Native │ Scalar │ Slowdown │ - \\├───────────┼──────┼───────────┼───────────┼───────────┤ + \\┌───────────┬──────┬───────────┬───────────┬───────────┬───────────────────────┐ + \\│ Operation │ Type │ Native │ @Vector │ Tensor{{1}} │ Slowdown Nat | Vec │ + \\├───────────┼──────┼───────────┼───────────┼───────────┼───────────────────────┤ \\ , .{}); inline for (Ops, 0..) |op_name, j| { inline for (Types, 0..) |T, tidx| { var native_total_ns: f64 = 0; - var quantity_total_ns: f64 = 0; + var vector_total_ns: f64 = 0; + var tensor_total_ns: f64 = 0; const M = Tensor(T, .{}, .{}, &.{1}); @@ -230,6 +231,31 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { const n_end = getTime(); native_total_ns += @as(f64, @floatFromInt(n_start.durationTo(n_end).toNanoseconds())); + const v_start = getTime(); + const va = getValT(T, 10); + const vb = getValT(T, 2); + for (0..ITERS) |_| { + // Native logic branch + _ = if (comptime std.mem.eql(u8, op_name, "add")) + if (comptime @typeInfo(T) == .int) va +| vb else va + vb + else if (comptime std.mem.eql(u8, op_name, "sub")) + if (comptime @typeInfo(T) == .int) va -| vb else va - vb + else if (comptime std.mem.eql(u8, op_name, "mul")) + if (comptime @typeInfo(T) == .int) va *| vb else va * vb + else if (comptime std.mem.eql(u8, op_name, "div")) + if (comptime @typeInfo(T) == .int) @divTrunc(va, vb) else va / vb + else if (comptime std.mem.eql(u8, op_name, "abs")) + if (comptime @typeInfo(T) == .int) @abs(va) else @as(T, @abs(va)) + else if (comptime std.mem.eql(u8, op_name, "eq")) + va == vb + else if (comptime std.mem.eql(u8, op_name, "gt")) + va > vb + else + unreachable; + } + const v_end = getTime(); + vector_total_ns += @as(f64, @floatFromInt(v_start.durationTo(v_end).toNanoseconds())); + // --- 2. Benchmark Scalar --- const q_start = getTime(); const qa = M.splat(getValT(T, 10)); @@ -254,22 +280,24 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { unreachable; } const q_end = getTime(); - quantity_total_ns += @as(f64, @floatFromInt(q_start.durationTo(q_end).toNanoseconds())); + tensor_total_ns += @as(f64, @floatFromInt(q_start.durationTo(q_end).toNanoseconds())); } }); const avg_n = (native_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); - const avg_q = (quantity_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); - const slowdown = avg_q / avg_n; + const avg_v = (vector_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); + const avg_t = (tensor_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); + const slowdown_nt = avg_t / avg_n; + const slowdown_vt = avg_t / avg_v; - try writer.print("│ {s:<9} │ {s:<4} │ {d:>7.2}ns │ {d:>7.2}ns │ {d:>8.2}x │\n", .{ - op_name, TNames[tidx], avg_n, avg_q, slowdown, + try writer.print("│ {s:<9} │ {s:<4} │ {d:>7.2}ns │ {d:>7.2}ns │ {d:>7.2}ns │ {d:>8.2}x {d:>8.2}x │\n", .{ + op_name, TNames[tidx], avg_n, avg_v, avg_t, slowdown_nt, slowdown_vt, }); } - if (j != Ops.len - 1) try writer.print("├───────────┼──────┼───────────┼───────────┼───────────┤\n", .{}); + if (j != Ops.len - 1) try writer.print("├───────────┼──────┼───────────┼───────────┼───────────┼───────────────────────┤\n", .{}); } - try writer.print("└───────────┴──────┴───────────┴───────────┴───────────┘\n", .{}); + try writer.print("└───────────┴──────┴───────────┴───────────┴───────────┴───────────────────────┘\n", .{}); } fn bench_crossTypeVsNative(writer: *std.Io.Writer) !void {