diff --git a/src/Tensor.zig b/src/Tensor.zig index a205b81..85be2b1 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -23,10 +23,10 @@ pub fn TensorStatic( comptime shape_: []const comptime_int, ) type { comptime { - if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1)."); - for (shape_) |s| { - if (s == 0) @compileError("Tensor shape dimensions must be strictly >= 1."); - } + if (shape_.len == 0) + @compileError("Tensor shape must have at least 1 dimension (rank >= 1)."); + for (shape_) |s| + if (s < 1) @compileError("Tensor shape dimensions must be strictly >= 1."); } @setEvalBranchQuota(100_000_000); @@ -112,7 +112,7 @@ pub fn TensorStatic( @compileError("rhs can only be a Tensor "); if (comptime !dims.eql(RhsType.dims)) @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str()); - if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape)) + if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape)) @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS."); const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape); @@ -127,10 +127,12 @@ pub fn TensorStatic( T, dims.add(@TypeOf(rhs).dims).argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), - shape_, + shape, ) { const RhsType = @TypeOf(rhs); - if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape)) + if (comptime !isTensor(RhsType)) + @compileError("rhs can only be a Tensor "); + if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape)) @compileError("Shape mismatch in mul: element-wise operations require identical shapes, or a scalar RHS."); const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape); @@ -146,10 +148,12 @@ pub fn TensorStatic( T, dims.sub(@TypeOf(rhs).dims).argsOpt(), sh.finerScales(Self, @TypeOf(rhs)).argsOpt(), - shape_, + shape, ) { const RhsType = @TypeOf(rhs); - if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape)) + if (comptime !isTensor(RhsType)) + @compileError("rhs can only be a Tensor "); + if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape)) @compileError("Shape mismatch in div: element-wise operations require identical shapes, or a scalar RHS."); const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape); @@ -169,7 +173,7 @@ pub fn TensorStatic( T, dims.scale(exp).argsOpt(), scales.argsOpt(), - shape_, + shape, ) { if (comptime exp == 0) return .{ .data = @splat(1) }; if (comptime exp == 1) return self; @@ -199,7 +203,7 @@ pub fn TensorStatic( T, dims.div(2).argsOpt(), scales.argsOpt(), - shape_, + shape, ) { if (comptime !dims.isSquare()) @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); @@ -235,7 +239,7 @@ pub fn TensorStatic( // Run validation checks FIRST before dealing with types if (comptime !dims.eql(Dest.dims)) @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ Dest.dims.str()); - if (comptime total != 1 and !sh.shapeEql(shape_, Dest.shape)) + if (comptime total != 1 and !sh.shapeEql(shape, Dest.shape)) @compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar."); const vec = if (comptime total == 1 and Dest.total != 1) @@ -317,10 +321,12 @@ pub fn TensorStatic( /// Resolve both sides to the finer scale, broadcasting shape {1} RHS if needed. inline fn resolveScalePair(self: *const Self, rhs: anytype) struct { l: Vec, r: Vec } { const RhsType = @TypeOf(rhs); - if (comptime RhsType.total != 1 and !sh.shapeEql(shape_, RhsType.shape)) + if (comptime !isTensor(RhsType)) + @compileError("rhs can only be a Tensor "); + if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape)) @compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS."); - const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape_); + const TargetType = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape); return .{ .l = self.to(TargetType).data, .r = rhs.to(TargetType).data }; } @@ -381,49 +387,51 @@ pub fn TensorStatic( pub inline fn contract( self: *const Self, - other: anytype, + rhs: anytype, comptime axis_a: usize, comptime axis_b: usize, ) blk: { - const OT = @TypeOf(other); + const RhsType = @TypeOf(rhs); + if (!isTensor(RhsType)) + @compileError("rhs can only be a Tensor "); if (axis_a >= rank) @compileError("contract: axis_a out of bounds"); - if (axis_b >= OT.rank) @compileError("contract: axis_b out of bounds"); - if (shape_[axis_a] != OT.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes"); + if (axis_b >= RhsType.rank) @compileError("contract: axis_b out of bounds"); + if (shape[axis_a] != RhsType.shape[axis_b]) @compileError("contract: shape mismatch at contraction axes"); - const sa = sh.shapeRemoveAxis(shape_, axis_a); - const sb = sh.shapeRemoveAxis(OT.shape, axis_b); + const sa = sh.shapeRemoveAxis(shape, axis_a); + const sb = sh.shapeRemoveAxis(RhsType.shape, axis_b); const rs_raw = sh.shapeCat(&sa, &sb); const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw; break :blk TensorStatic( T, - dims.add(OT.dims).argsOpt(), - sh.finerScales(Self, OT).argsOpt(), + dims.add(RhsType.dims).argsOpt(), + sh.finerScales(Self, RhsType).argsOpt(), rs, ); } { - const OT = @TypeOf(other); - const k: usize = comptime shape_[axis_a]; // contraction dimension + const RhsType = @TypeOf(rhs); + const k: usize = comptime shape[axis_a]; // contraction dimension - const sa = comptime sh.shapeRemoveAxis(shape_, axis_a); - const sb = comptime sh.shapeRemoveAxis(OT.shape, axis_b); + const sa = comptime sh.shapeRemoveAxis(shape, axis_a); + const sb = comptime sh.shapeRemoveAxis(RhsType.shape, axis_b); const rs_raw = comptime sh.shapeCat(&sa, &sb); const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw; const ResultType = TensorStatic( T, - dims.add(OT.dims).argsOpt(), - sh.finerScales(Self, OT).argsOpt(), + dims.add(RhsType.dims).argsOpt(), + sh.finerScales(Self, RhsType).argsOpt(), rs, ); - const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, OT).argsOpt(), shape_); - const OtherNorm = TensorStatic(T, OT.dims.argsOpt(), sh.finerScales(Self, OT).argsOpt(), OT.shape); + const SelfNorm = TensorStatic(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape); + const OtherNorm = TensorStatic(T, RhsType.dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), RhsType.shape); const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; - const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data; + const b_data = if (comptime RhsType == OtherNorm) rhs.data else rhs.to(OtherNorm).data; // FAST PATH: Dot Product - if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) { + if (comptime rank == 1 and RhsType.rank == 1 and axis_a == 0 and axis_b == 0) { if (comptime !sh.isInt(T)) { return .{ .data = @splat(@reduce(.Add, a_data * b_data)) }; } else { @@ -438,13 +446,13 @@ pub fn TensorStatic( // --- ZERO-COST COERCION TO ARRAYS FOR RUNTIME INDEXING --- const a_arr: [total]T = a_data; - const b_arr: [OT.total]T = b_data; + const b_arr: [RhsType.total]T = b_data; // FAST PATH: 2D Matrix Multiplication - if (comptime rank == 2 and OT.rank == 2 and axis_a == 1 and axis_b == 0) { - const rows = shape_[0]; - const cols = OT.shape[1]; - const inner = shape_[1]; + if (comptime rank == 2 and RhsType.rank == 2 and axis_a == 1 and axis_b == 0) { + const rows = shape[0]; + const cols = RhsType.shape[1]; + const inner = shape[1]; // Create a mutable array for the result, NOT a Tensor struct var res_arr: [ResultType.total]T = undefined; @@ -454,7 +462,7 @@ pub fn TensorStatic( var acc: T = 0; for (0..inner) |id| { const a_flat = i * _strides[0] + id * _strides[1]; - const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1]; + const b_flat = id * RhsType.strides_arr[0] + j * RhsType.strides_arr[1]; // Use a_arr and b_arr here if (comptime sh.isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat]; @@ -484,9 +492,9 @@ pub fn TensorStatic( var acc: T = 0; for (0..k) |ki| { const a_coords = sh.insertAxis(rank, axis_a, ki, &a_free); - const b_coords = sh.insertAxis(OT.rank, axis_b, ki, &b_free); + const b_coords = sh.insertAxis(RhsType.rank, axis_b, ki, &b_free); const a_flat = sh.encodeFlatCoords(&a_coords, rank, _strides); - const b_flat = sh.encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr); + const b_flat = sh.encodeFlatCoords(&b_coords, RhsType.rank, RhsType.strides_arr); // Use a_arr and b_arr here if (comptime sh.isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat]; @@ -509,9 +517,10 @@ pub fn TensorStatic( ) { const RhsType = @TypeOf(rhs); - if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) { + if (!isTensor(RhsType)) + @compileError("rhs can only be a Tensor "); + if (comptime rank != 1 or shape[0] != 3 or RhsType.rank != 1 or RhsType.shape[0] != 3) @compileError("cross product is only defined for 3D vectors (rank-1, length 3)"); - } // Bring both to the same scale (e.g., mm vs m) const p = self.resolveScalePair(rhs); diff --git a/src/benchmark.zig b/src/benchmark.zig index 8b567c5..9b7e042 100644 --- a/src/benchmark.zig +++ b/src/benchmark.zig @@ -10,27 +10,27 @@ pub fn main(init: std.process.Init) !void { io = init.io; - // try vectorSIMDvsNative(f64, &stdout_writer.interface); - // try stdout_writer.flush(); - // try vectorSIMDvsNative(f32, &stdout_writer.interface); - // try stdout_writer.flush(); - // try vectorSIMDvsNative(i32, &stdout_writer.interface); - // try stdout_writer.flush(); - // try vectorSIMDvsNative(i64, &stdout_writer.interface); - // try stdout_writer.flush(); - // try vectorSIMDvsNative(i128, &stdout_writer.interface); - // try stdout_writer.flush(); - // - // try bench_Scalar(&stdout_writer.interface); - // try stdout_writer.flush(); + try vectorSIMDvsNative(f64, &stdout_writer.interface); + try stdout_writer.flush(); + try vectorSIMDvsNative(f32, &stdout_writer.interface); + try stdout_writer.flush(); + try vectorSIMDvsNative(i32, &stdout_writer.interface); + try stdout_writer.flush(); + try vectorSIMDvsNative(i64, &stdout_writer.interface); + try stdout_writer.flush(); + try vectorSIMDvsNative(i128, &stdout_writer.interface); + try stdout_writer.flush(); + + try bench_Scalar(&stdout_writer.interface); + try stdout_writer.flush(); try bench_vsNative(&stdout_writer.interface); try stdout_writer.flush(); - // try bench_crossTypeVsNative(&stdout_writer.interface); - // try stdout_writer.flush(); - // try bench_Vector(&stdout_writer.interface); - // try stdout_writer.flush(); - // try bench_HighDimTensor(&stdout_writer.interface); - // try stdout_writer.flush(); + try bench_crossTypeVsNative(&stdout_writer.interface); + try stdout_writer.flush(); + try bench_Vector(&stdout_writer.interface); + try stdout_writer.flush(); + try bench_HighDimTensor(&stdout_writer.interface); + try stdout_writer.flush(); } fn getTime() Io.Timestamp { @@ -128,7 +128,7 @@ fn bench_Scalar(writer: *std.Io.Writer) !void { else if (comptime std.mem.eql(u8, op_name, "gt")) (M.splat(getVal(T, i, 63))).gt(M.splat(getVal(T, i +% 3, 63))) else - (M.splat(getVal(T, i, 63))).mul(3); + (M.splat(getVal(T, i, 63))).mul(M.splat(3)); }, ); }