From 26ff02c50fdaf1f796bc6c76c92cba891217b8f4 Mon Sep 17 00:00:00 2001 From: adrien Date: Tue, 28 Apr 2026 13:10:14 +0200 Subject: [PATCH] Changed Tensor to use *const --- src/Tensor.zig | 143 ++++++++++++++++++++------------------ src/benchmark.zig | 170 +++++++++++++++++++++++----------------------- 2 files changed, 160 insertions(+), 153 deletions(-) diff --git a/src/Tensor.zig b/src/Tensor.zig index dc67199..2484dac 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -139,8 +139,8 @@ inline fn RhsTensorType(comptime T: type, comptime Rhs: type) type { /// Take the anyvalue coming from operation and if it is a Tensor, return it. /// If it is a float or int, return a Tensor(T, .{}, .{}, .{1}).splat(r). -inline fn toRhsTensor(comptime T: type, r: anytype) RhsTensorType(T, @TypeOf(r)) { - const Rhs = @TypeOf(r); +inline fn toRhsTensor(comptime T: type, r: anytype) *const RhsTensorType(T, @TypeOf(r.*)) { + const Rhs = @TypeOf(r.*); if (comptime isTensor(Rhs)) return r; const scalar: T = switch (@typeInfo(Rhs)) { .comptime_int => switch (comptime @typeInfo(T)) { @@ -252,7 +252,7 @@ pub fn Tensor( inline fn RhsT(comptime Rhs: type) type { return RhsTensorType(T, Rhs); } - inline fn rhs(r: anytype) RhsT(@TypeOf(r)) { + inline fn rhs(r: anytype) *const RhsT(@TypeOf(r.*)) { return toRhsTensor(T, r); } @@ -265,20 +265,20 @@ pub fn Tensor( /// Element-wise add. Dimensions must match; scales resolve to finer. /// RHS must have the same shape as self, or total == 1 (broadcast).
- pub inline fn add(self: Self, r: anytype) Tensor( + pub inline fn add(self: *const Self, r: anytype) *const Tensor( T, dims.argsOpt(), finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { const rhs_t = rhs(r); - const RhsType = @TypeOf(r); + const RhsType = @TypeOf(rhs_t.*); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS."); if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too - return .{ .data = if (comptime isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data }; + return &.{ .data = if (comptime isInt(T)) self.data +| rhs_t.data else self.data + rhs_t.data }; const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; @@ -287,25 +287,25 @@ pub fn Tensor( const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; - return .{ .data = if (comptime isInt(T)) l +| rr else l + rr }; + return &.{ .data = if (comptime isInt(T)) l +| rr else l + rr }; } /// Element-wise sub. Dimensions must match; scales resolve to finer. /// RHS must have the same shape as self, or total == 1 (broadcast). 
- pub inline fn sub(self: Self, r: anytype) Tensor( + pub inline fn sub(self: *const Self, r: anytype) *const Tensor( T, dims.argsOpt(), finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { const rhs_t = rhs(r); - const RhsType = @TypeOf(rhs_t); + const RhsType = @TypeOf(rhs_t.*); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in sub: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in sub: element-wise operations require identical shapes, or a scalar RHS."); if (comptime total == 1 and scales.eql(RhsType.scales)) // Here rhs_t has to be {1} too - return .{ .data = if (comptime isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data }; + return &.{ .data = if (comptime isInt(T)) self.data -| rhs_t.data else self.data - rhs_t.data }; const TargetType = Tensor(T, dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), shape_); const l: Vec = if (comptime Self == TargetType) self.data else self.to(TargetType).data; @@ -314,19 +314,19 @@ pub fn Tensor( const rn = if (comptime RhsType == RhsNorm) rhs_t else rhs_t.to(RhsNorm); break :blk broadcastToVec(RhsNorm, rn); }; - return .{ .data = if (comptime isInt(T)) l -| rr else l - rr }; + return &.{ .data = if (comptime isInt(T)) l -| rr else l - rr }; } /// Element-wise multiply. Dimension exponents summed. /// Shape {1} RHS is automatically broadcast across all elements. 
- pub inline fn mul(self: Self, r: anytype) Tensor( + pub inline fn mul(self: *const Self, r: anytype) *const Tensor( T, dims.add(RhsT(@TypeOf(r)).dims).argsOpt(), finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { const rhs_q = rhs(r); - const RhsType = @TypeOf(rhs_q); + const RhsType = @TypeOf(rhs_q.*); if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS."); @@ -334,20 +334,20 @@ pub fn Tensor( const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); - const rr: Vec = broadcastToVec(RhsNorm, rr_base); - return .{ .data = if (comptime isInt(T)) l *| rr else l * rr }; + const rr: Vec = broadcastToVec(RhsNorm, rr_base.*); + return &.{ .data = if (comptime isInt(T)) l *| rr else l * rr }; } /// Element-wise divide. Dimension exponents subtracted. /// Shape {1} RHS is automatically broadcast across all elements. 
- pub inline fn div(self: Self, r: anytype) Tensor( + pub inline fn div(self: *const Self, r: anytype) *const Tensor( T, dims.sub(RhsT(@TypeOf(r)).dims).argsOpt(), finerScales(Self, RhsT(@TypeOf(r))).argsOpt(), shape_, ) { const rhs_q = rhs(r); - const RhsType = @TypeOf(rhs_q); + const RhsType = @TypeOf(rhs_q.*); if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS."); @@ -355,17 +355,17 @@ pub fn Tensor( const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); const l: Vec = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); - const rr: Vec = broadcastToVec(RhsNorm, rr_base); + const rr: Vec = broadcastToVec(RhsNorm, rr_base.*); if (comptime isInt(T)) { - return .{ .data = @divTrunc(l, rr) }; + return &.{ .data = @divTrunc(l, rr) }; } else { - return .{ .data = l / rr }; + return &.{ .data = l / rr }; } } /// Absolute value of every element. - pub inline fn abs(self: Self) Self { - return .{ .data = @bitCast(@abs(self.data)) }; + pub inline fn abs(self: *const Self) *const Self { + return &.{ .data = @bitCast(@abs(self.data)) }; } /// Raise every element to a comptime integer exponent. @@ -375,7 +375,7 @@ pub fn Tensor( scales.argsOpt(), shape_, ) { - if (comptime exp == 0) return .{ .data = @splat(1) }; + if (comptime exp == 0) return &.{ .data = @splat(1) }; if (comptime exp == 1) return self; var base = self.data; @@ -395,7 +395,7 @@ pub fn Tensor( if (comptime !isInt(T) and exp < 0) { result = @as(Vec, @splat(1)) / result; } - return .{ .data = result }; + return &.{ .data = result }; } /// Square root of every element. All dimension exponents must be even. 
@@ -408,7 +408,7 @@ pub fn Tensor( if (comptime !dims.isSquare()) @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); if (comptime @typeInfo(T) == .float) { - return .{ .data = @sqrt(self.data) }; // Float is natively vectorized! + return &.{ .data = @sqrt(self.data) }; // Float is natively vectorized! } else { const arr: [total]T = self.data; // Add this! var res_arr: [total]T = undefined; @@ -417,13 +417,13 @@ pub fn Tensor( const v = arr[i]; res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); } - return .{ .data = res_arr }; + return &.{ .data = res_arr }; } } /// Negate every element. pub inline fn negate(self: Self) Self { - return .{ .data = -self.data }; + return &.{ .data = -self.data }; } /// Convert to a compatible Tensor type. @@ -431,9 +431,9 @@ pub fn Tensor( /// • Dest.shape must equal self.shape, or Dest.total == 1 (scalar pattern). /// • Scale ratio is computed fully at comptime; only a SIMD multiply at runtime. 
pub inline fn to( - self: Self, + self: *const Self, comptime Dest: type, - ) Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) { + ) *const Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_) { const ActualDest = Tensor(Dest.ValueType, Dest.dims.argsOpt(), Dest.scales.argsOpt(), shape_); if (comptime Self == ActualDest) return self; @@ -449,14 +449,14 @@ pub fn Tensor( const DestVec = @Vector(total, DestT); if (comptime ratio == 1.0 and T == DestT) - return .{ .data = self.data }; + return &.{ .data = self.data }; // If ratio is 1, handle type conversion correctly based on BOTH source and dest types if (comptime ratio == 1.0) { const T_info = @typeInfo(T); const Dest_info = @typeInfo(DestT); - return .{ + return &.{ .data = if (comptime T_info == .int and Dest_info == .int) @as(DestVec, @intCast(self.data)) else if (comptime T_info == .float and Dest_info == .float) @@ -472,22 +472,22 @@ pub fn Tensor( if (comptime T == DestT) { if (comptime @typeInfo(T) == .float) - return .{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) }; + return &.{ .data = self.data * @as(DestVec, @splat(@as(T, @floatCast(ratio)))) }; if (comptime ratio >= 1.0) { const mult: T = comptime @intFromFloat(@round(ratio)); - return .{ .data = self.data *| @as(Vec, @splat(mult)) }; + return &.{ .data = self.data *| @as(Vec, @splat(mult)) }; } else { const div_val: T = comptime @intFromFloat(@round(1.0 / ratio)); const half: T = comptime @divTrunc(div_val, 2); if (comptime @typeInfo(T).int.signedness == .unsigned) { - return .{ .data = @divTrunc(self.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) }; + return &.{ .data = @divTrunc(self.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) }; } else { // Vectorized branchless negative handling const is_pos = self.data >= @as(Vec, @splat(0)); const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half))); - return .{ .data = @divTrunc(self.data + 
offsets, @as(Vec, @splat(div_val))) }; + return &.{ .data = @divTrunc(self.data + offsets, @as(Vec, @splat(div_val))) }; } } } @@ -503,8 +503,8 @@ pub fn Tensor( const scaled = float_vec * @as(FVec, @splat(ratio)); return switch (comptime @typeInfo(DestT)) { - .float => .{ .data = @floatCast(scaled) }, - .int => .{ .data = @intFromFloat(@round(scaled)) }, + .float => &.{ .data = @floatCast(scaled) }, + .int => &.{ .data = @intFromFloat(@round(scaled)) }, else => unreachable, }; } @@ -516,8 +516,8 @@ pub fn Tensor( } /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. - inline fn resolveScalePair(self: Self, rhs_q: anytype) struct { l: Vec, r: Vec } { - const RhsType = @TypeOf(rhs_q); + inline fn resolveScalePair(self: *const Self, rhs_q: anytype) struct { l: Vec, r: Vec } { + const RhsType = @TypeOf(rhs_q.*); if (comptime RhsType.total != 1 and !shapeEql(shape_, RhsType.shape)) @compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS."); @@ -526,74 +526,74 @@ pub fn Tensor( const rr: Vec = blk: { const RhsNorm = Tensor(T, RhsType.dims.argsOpt(), finerScales(Self, RhsType).argsOpt(), RhsType.shape); const rn = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); - break :blk broadcastToVec(RhsNorm, rn); + break :blk broadcastToVec(RhsNorm, rn.*); }; return .{ .l = l, .r = rr }; } - pub inline fn eq(self: Self, r: anytype) CmpResult { + pub inline fn eq(self: *const Self, r: anytype) CmpResult { const rhs_q = rhs(r); - if (comptime !dims.eql(@TypeOf(rhs_q).dims)) + if (comptime !dims.eql(@TypeOf(rhs_q.*).dims)) @compileError("Dimension mismatch in ne."); const p = resolveScalePair(self, rhs_q); return cmpResult(p.l == p.r); } - pub inline fn ne(self: Self, r: anytype) CmpResult { + pub inline fn ne(self: *const Self, r: anytype) CmpResult { const rhs_q = rhs(r); - if (comptime !dims.eql(@TypeOf(rhs_q).dims)) + if (comptime !dims.eql(@TypeOf(rhs_q.*).dims)) 
@compileError("Dimension mismatch in ne."); const p = resolveScalePair(self, rhs_q); return cmpResult(p.l != p.r); } - pub inline fn gt(self: Self, r: anytype) CmpResult { + pub inline fn gt(self: *const Self, r: anytype) CmpResult { const rhs_q = rhs(r); - if (comptime !dims.eql(@TypeOf(rhs_q).dims)) + if (comptime !dims.eql(@TypeOf(rhs_q.*).dims)) @compileError("Dimension mismatch in gt."); const p = resolveScalePair(self, rhs_q); return cmpResult(p.l > p.r); } - pub inline fn gte(self: Self, r: anytype) CmpResult { + pub inline fn gte(self: *const Self, r: anytype) CmpResult { const rhs_q = rhs(r); - if (comptime !dims.eql(@TypeOf(rhs_q).dims)) + if (comptime !dims.eql(@TypeOf(rhs_q.*).dims)) @compileError("Dimension mismatch in gte."); const p = resolveScalePair(self, rhs_q); return cmpResult(p.l >= p.r); } - pub inline fn lt(self: Self, r: anytype) CmpResult { + pub inline fn lt(self: *const Self, r: anytype) CmpResult { const rhs_q = rhs(r); - if (comptime !dims.eql(@TypeOf(rhs_q).dims)) + if (comptime !dims.eql(@TypeOf(rhs_q.*).dims)) @compileError("Dimension mismatch in lt."); const p = resolveScalePair(self, rhs_q); return cmpResult(p.l < p.r); } - pub inline fn lte(self: Self, r: anytype) CmpResult { + pub inline fn lte(self: *const Self, r: anytype) CmpResult { const rhs_q = rhs(r); - if (comptime !dims.eql(@TypeOf(rhs_q).dims)) + if (comptime !dims.eql(@TypeOf(rhs_q.*).dims)) @compileError("Dimension mismatch in lte."); const p = resolveScalePair(self, rhs_q); return cmpResult(p.l <= p.r); } /// True iff every element is equal after scale resolution. - pub inline fn eqAll(self: Self, other: anytype) bool { - if (comptime !dims.eql(@TypeOf(other).dims)) + pub inline fn eqAll(self: *const Self, other: anytype) bool { + if (comptime !dims.eql(@TypeOf(other.*).dims)) @compileError("Dimension mismatch in eqAll."); const p = resolveScalePair(self, other); return @reduce(.And, p.l == p.r); } /// True iff any element differs after scale resolution. 
- pub inline fn neAll(self: Self, other: anytype) bool { + pub inline fn neAll(self: *const Self, other: anytype) bool { return !self.eqAll(other); } pub inline fn contract( - self: Self, + self: *const Self, other: anytype, comptime axis_a: usize, comptime axis_b: usize, @@ -607,7 +607,7 @@ pub fn Tensor( const sb = shapeRemoveAxis(OT.shape, axis_b); const rs_raw = shapeCat(&sa, &sb); const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw; - break :blk Tensor( + break :blk *const Tensor( T, dims.add(OT.dims).argsOpt(), finerScales(Self, OT).argsOpt(), @@ -645,7 +645,7 @@ pub fn Tensor( const mul_arr: [total]T = a_data *| b_data; var acc: T = 0; for (mul_arr) |val| acc +|= val; - return .{ .data = @splat(acc) }; + return &.{ .data = @splat(acc) }; } } @@ -677,7 +677,7 @@ pub fn Tensor( } } // Return the initialized Tensor struct - return .{ .data = res_arr }; + return &.{ .data = res_arr }; } // FALLBACK PATH @@ -709,12 +709,12 @@ pub fn Tensor( } // Return the initialized Tensor struct - return .{ .data = result_arr }; + return &.{ .data = result_arr }; } /// 3D Cross Product. Only defined for Rank-1 tensors of length 3. /// Result dimensions are the sum of input dimensions. - pub inline fn cross(self: Self, other: anytype) Tensor( + pub inline fn cross(self: *const Self, other: anytype) *const Tensor( T, dims.add(RhsT(@TypeOf(other)).dims).argsOpt(), finerScales(Self, RhsT(@TypeOf(other))).argsOpt(), @@ -743,16 +743,16 @@ pub fn Tensor( res[2] = (l[0] * r[1]) - (l[1] * r[0]); } - return .{ .data = res }; + return &.{ .data = res }; } /// Sum of squared elements. Cheaper than length(); use for ordering. - pub inline fn lengthSqr(self: Self) T { + pub inline fn lengthSqr(self: *const Self) T { return @reduce(.Add, self.data * self.data); } /// Euclidean length (L2 norm). 
- pub inline fn length(self: Self) T { + pub inline fn length(self: *const Self) T { const sq = self.lengthSqr(); if (comptime @typeInfo(T) == .int) { const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); @@ -762,17 +762,17 @@ pub fn Tensor( } /// Product of all elements. Result has shape {1}; dimension exponent * total. - pub inline fn product(self: Self) Tensor( + pub inline fn product(self: *const Self) *const Tensor( T, dims.scale(@as(comptime_int, total)).argsOpt(), scales.argsOpt(), &.{1}, ) { - return .{ .data = .{@reduce(.Mul, self.data)} }; + return &.{ .data = .{@reduce(.Mul, self.data)} }; } pub fn formatNumber( - self: Self, + self: *const Self, writer: *std.Io.Writer, options: std.fmt.Number, ) !void { @@ -1393,3 +1393,12 @@ test "Tensor strides_arr correctness" { try std.testing.expectEqual(4, T3.strides_arr[1]); try std.testing.expectEqual(1, T3.strides_arr[2]); } + +test "Big Tensor" { + const Tens = Tensor(f32, .{}, .{}, &.{1_000_000}); + const t1 = Tens.splat(2); + const t2 = Tens.splat(3); + const t3 = t1.add(t2); + + try std.testing.expectApproxEqAbs(5, t3.data[0], 0.0001); +} diff --git a/src/benchmark.zig b/src/benchmark.zig index c6fb592..080b96e 100644 --- a/src/benchmark.zig +++ b/src/benchmark.zig @@ -26,11 +26,11 @@ pub fn main(init: std.process.Init) !void { try bench_vsNative(&stdout_writer.interface); try stdout_writer.flush(); // try bench_crossTypeVsNative(&stdout_writer.interface); - try stdout_writer.flush(); - try bench_Vector(&stdout_writer.interface); - try stdout_writer.flush(); - try bench_HighDimTensor(&stdout_writer.interface); - try stdout_writer.flush(); + // try stdout_writer.flush(); + // try bench_Vector(&stdout_writer.interface); + // try stdout_writer.flush(); + // try bench_HighDimTensor(&stdout_writer.interface); + // try stdout_writer.flush(); } fn getTime() Io.Timestamp { @@ -180,8 +180,8 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { } }.f; - const Types = .{ f64, i64, i128, f32, f64 }; - const TNames = 
.{ "f64", "i64", "i128", "f32", "f64" }; + const Types = .{ i32, i64, i128, f32, f64 }; + const TNames = .{ "i32", "i64", "i128", "f32", "f64" }; // Expanded Ops to match bench_Scalar const Ops = .{ "add", "sub", "mul", "div", "abs", "eq", "gt" }; @@ -203,86 +203,84 @@ fn bench_vsNative(writer: *std.Io.Writer) !void { const M = Tensor(T, .{}, .{}, &.{1}); - std.mem.doNotOptimizeAway({ - for (0..SAMPLES) |_| { - // --- 1. Benchmark Native --- - const n_start = getTime(); - const a = getValT(T, 10); - const b = getValT(T, 2); - for (0..ITERS) |_| { - // Native logic branch - _ = if (comptime std.mem.eql(u8, op_name, "add")) - if (comptime @typeInfo(T) == .int) a +| b else a + b - else if (comptime std.mem.eql(u8, op_name, "sub")) - if (comptime @typeInfo(T) == .int) a -| b else a - b - else if (comptime std.mem.eql(u8, op_name, "mul")) - if (comptime @typeInfo(T) == .int) a *| b else a * b - else if (comptime std.mem.eql(u8, op_name, "div")) - if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b - else if (comptime std.mem.eql(u8, op_name, "abs")) - if (comptime @typeInfo(T) == .int) @abs(a) else @as(T, @abs(a)) - else if (comptime std.mem.eql(u8, op_name, "eq")) - a == b - else if (comptime std.mem.eql(u8, op_name, "gt")) - a > b - else - unreachable; - } - const n_end = getTime(); - native_total_ns += @as(f64, @floatFromInt(n_start.durationTo(n_end).toNanoseconds())); - - const v_start = getTime(); - const va = getValT(T, 10); - const vb = getValT(T, 2); - for (0..ITERS) |_| { - // Native logic branch - _ = if (comptime std.mem.eql(u8, op_name, "add")) - if (comptime @typeInfo(T) == .int) va +| vb else va + vb - else if (comptime std.mem.eql(u8, op_name, "sub")) - if (comptime @typeInfo(T) == .int) va -| vb else va - vb - else if (comptime std.mem.eql(u8, op_name, "mul")) - if (comptime @typeInfo(T) == .int) va *| vb else va * vb - else if (comptime std.mem.eql(u8, op_name, "div")) - if (comptime @typeInfo(T) == .int) @divTrunc(va, vb) else va / vb - else 
if (comptime std.mem.eql(u8, op_name, "abs")) - if (comptime @typeInfo(T) == .int) @abs(va) else @as(T, @abs(va)) - else if (comptime std.mem.eql(u8, op_name, "eq")) - va == vb - else if (comptime std.mem.eql(u8, op_name, "gt")) - va > vb - else - unreachable; - } - const v_end = getTime(); - vector_total_ns += @as(f64, @floatFromInt(v_start.durationTo(v_end).toNanoseconds())); - - // --- 2. Benchmark Scalar --- - const q_start = getTime(); - const qa = M.splat(getValT(T, 10)); - const qb = M.splat(getValT(T, 2)); - for (0..ITERS) |_| { - // Scalar logic branch - _ = if (comptime std.mem.eql(u8, op_name, "add")) - qa.add(qb) - else if (comptime std.mem.eql(u8, op_name, "sub")) - qa.sub(qb) - else if (comptime std.mem.eql(u8, op_name, "mul")) - qa.mul(qb) - else if (comptime std.mem.eql(u8, op_name, "div")) - qa.div(qb) - else if (comptime std.mem.eql(u8, op_name, "abs")) - qa.abs() - else if (comptime std.mem.eql(u8, op_name, "eq")) - qa.eq(qb) - else if (comptime std.mem.eql(u8, op_name, "gt")) - qa.gt(qb) - else - unreachable; - } - const q_end = getTime(); - tensor_total_ns += @as(f64, @floatFromInt(q_start.durationTo(q_end).toNanoseconds())); + for (0..SAMPLES) |_| { + // --- 1. 
Benchmark Native --- + const n_start = getTime(); + const a = getValT(T, 10); + const b = getValT(T, 2); + for (0..ITERS) |_| { + // Native logic branch + _ = if (comptime std.mem.eql(u8, op_name, "add")) + if (comptime @typeInfo(T) == .int) a +| b else a + b + else if (comptime std.mem.eql(u8, op_name, "sub")) + if (comptime @typeInfo(T) == .int) a -| b else a - b + else if (comptime std.mem.eql(u8, op_name, "mul")) + if (comptime @typeInfo(T) == .int) a *| b else a * b + else if (comptime std.mem.eql(u8, op_name, "div")) + if (comptime @typeInfo(T) == .int) @divTrunc(a, b) else a / b + else if (comptime std.mem.eql(u8, op_name, "abs")) + if (comptime @typeInfo(T) == .int) @abs(a) else @as(T, @abs(a)) + else if (comptime std.mem.eql(u8, op_name, "eq")) + a == b + else if (comptime std.mem.eql(u8, op_name, "gt")) + a > b + else + unreachable; } - }); + const n_end = getTime(); + native_total_ns += @as(f64, @floatFromInt(n_start.durationTo(n_end).toNanoseconds())); + + const v_start = getTime(); + const va = @Vector(1, T){getValT(T, 10)}; + const vb = @Vector(1, T){getValT(T, 2)}; + for (0..ITERS) |_| { + // Native logic branch + _ = if (comptime std.mem.eql(u8, op_name, "add")) + if (comptime @typeInfo(T) == .int) va +| vb else va + vb + else if (comptime std.mem.eql(u8, op_name, "sub")) + if (comptime @typeInfo(T) == .int) va -| vb else va - vb + else if (comptime std.mem.eql(u8, op_name, "mul")) + if (comptime @typeInfo(T) == .int) va *| vb else va * vb + else if (comptime std.mem.eql(u8, op_name, "div")) + if (comptime @typeInfo(T) == .int) @divTrunc(va, vb) else va / vb + else if (comptime std.mem.eql(u8, op_name, "abs")) + if (comptime @typeInfo(T) == .int) @as(T, @intCast(@abs(va[0]))) else @abs(va) + else if (comptime std.mem.eql(u8, op_name, "eq")) + va == vb + else if (comptime std.mem.eql(u8, op_name, "gt")) + va > vb + else + unreachable; + } + const v_end = getTime(); + vector_total_ns += @as(f64, 
@floatFromInt(v_start.durationTo(v_end).toNanoseconds())); + + // --- 2. Benchmark Scalar --- + const q_start = getTime(); + const qa = M.splat(getValT(T, 10)); + const qb = M.splat(getValT(T, 2)); + for (0..ITERS) |_| { + // Scalar logic branch + _ = if (comptime std.mem.eql(u8, op_name, "add")) + &qa.add(&qb) + else if (comptime std.mem.eql(u8, op_name, "sub")) + &qa.sub(&qb) + else if (comptime std.mem.eql(u8, op_name, "mul")) + &qa.mul(&qb) + else if (comptime std.mem.eql(u8, op_name, "div")) + &qa.div(&qb) + else if (comptime std.mem.eql(u8, op_name, "abs")) + &qa.abs() + else if (comptime std.mem.eql(u8, op_name, "eq")) + &qa.eq(&qb) + else if (comptime std.mem.eql(u8, op_name, "gt")) + &qa.gt(&qb) + else + unreachable; + } + const q_end = getTime(); + tensor_total_ns += @as(f64, @floatFromInt(q_start.durationTo(q_end).toNanoseconds())); + } const avg_n = (native_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS)); const avg_v = (vector_total_ns / SAMPLES) / @as(f64, @floatFromInt(ITERS));