diff --git a/src/Dimensions.zig b/src/Dimensions.zig index e91e78d..ef37fe1 100644 --- a/src/Dimensions.zig +++ b/src/Dimensions.zig @@ -51,17 +51,17 @@ data: std.EnumArray(Dimension, comptime_int), /// Unspecified dimensions default to 0. pub fn init(comptime init_val: ArgOpts) Self { var s = Self{ .data = std.EnumArray(Dimension, comptime_int).initFill(0) }; - inline for (std.meta.fields(@TypeOf(init_val))) |f| + for (std.meta.fields(@TypeOf(init_val))) |f| s.data.set(@field(Dimension, f.name), @field(init_val, f.name)); return s; } pub fn initFill(comptime val: comptime_int) Self { - return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) }; + comptime return .{ .data = std.EnumArray(Dimension, comptime_int).initFill(val) }; } pub fn get(comptime self: Self, comptime key: Dimension) comptime_int { - return self.data.get(key); + comptime return self.data.get(key); } pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void { @@ -70,40 +70,40 @@ pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: i8) void pub fn argsOpt(self: Self) ArgOpts { var args: ArgOpts = undefined; - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| @field(args, @tagName(d)) = self.get(d); - return args; + comptime return args; } /// Add exponents component-wise. Used internally by `mul`. pub fn add(comptime a: Self, comptime b: Self) Self { var result = Self.initFill(0); - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) + b.get(d)); - return result; + comptime return result; } /// Subtract exponents component-wise. Used internally by `div`. pub fn sub(comptime a: Self, comptime b: Self) Self { var result = Self.initFill(0); - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) - b.get(d)); - return result; + comptime return result; } /// Multiply exponents by a scalar integer. Used internally by `pow` in Scalar. pub fn scale(comptime a: Self, comptime exp: comptime_int) Self { var result = Self.initFill(0); - inline for (std.enums.values(Dimension)) |d| + for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) * exp); - return result; + comptime return result; } pub fn div(comptime a: Self, comptime exp: comptime_int) Self { var result = Self.initFill(0); inline for (std.enums.values(Dimension)) |d| result.set(d, a.get(d) / exp); - return result; + comptime return result; } /// Returns true if every dimension exponent is equal. Used to enforce type compatibility in `add`, `sub`, `to`. diff --git a/src/Scales.zig b/src/Scales.zig index 06031fb..ea6ff08 100644 --- a/src/Scales.zig +++ b/src/Scales.zig @@ -65,7 +65,7 @@ pub const UnitScale = enum(isize) { } pub inline fn getFactor(self: @This()) comptime_float { - return comptime switch (self) { + comptime return switch (self) { // Standard SI Exponents inline .P, .T, .G, .M, .k, .h, .da, .none, .d, .c, .m, .u, .n, .p, .f => std.math.pow(f64, 10.0, @floatFromInt(@intFromEnum(self))), @@ -83,7 +83,7 @@ pub const UnitScale = enum(isize) { inline .lb => 453.59237, inline .st => 6350.29318, - inline else => @floatFromInt(@intFromEnum(self)), + else => @floatFromInt(@intFromEnum(self)), }; } }; @@ -103,11 +103,11 @@ pub fn init(comptime init_val: ArgOpts) Self { else s.data.set(@field(Dimension, f.name), @field(init_val, f.name)); } - return s; + return comptime s; } pub fn initFill(comptime val: UnitScale) Self { - return comptime .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) }; + comptime return .{ .data = std.EnumArray(Dimension, UnitScale).initFill(val) }; } pub fn get(comptime self: Self, comptime key: Dimension) UnitScale { @@ -115,7 +115,7 @@ pub fn get(comptime self: Self, comptime key: Dimension) UnitScale { } pub fn set(comptime self: *Self, comptime key: Dimension, comptime val: UnitScale) void { - comptime self.data.set(key, val); + self.data.set(key, val); } pub fn argsOpt(self: Self) ArgOpts { @@ -144,5 +144,5 @@ pub inline fn getFactor(comptime s: Self, comptime d: Dimensions) comptime_float factor /= base; } } - return comptime factor; + comptime return factor; } diff --git a/src/Tensor.zig b/src/Tensor.zig index ff97dc0..1cc3a91 100644 --- a/src/Tensor.zig +++ b/src/Tensor.zig @@ -8,14 +8,14 @@ const Dimension = Dimensions.Dimension; // Comptime utilities // ───────────────────────────────────────────────────────────────────────────── -pub fn shapeTotal(comptime shape: []const usize) usize { - var t: usize = 1; +pub fn shapeTotal(comptime shape: []const comptime_int) usize { + var t: comptime_int = 1; for (shape) |s| t *= s; return t; } /// Check if two shapes are strictly identical. -pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool { +pub fn shapeEql(comptime a: []const comptime_int, comptime b: []const comptime_int) bool { if (a.len != b.len) return false; for (a, 0..) |v, i| if (v != b[i]) return false; @@ -25,21 +25,21 @@ pub fn shapeEql(comptime a: []const usize, comptime b: []const usize) bool { /// Row-major (C-order) strides: strides[i] = product(shape[i+1..]). /// e.g. shape {3, 4} → strides {4, 1} /// shape {2, 3, 4} → strides {12, 4, 1} -pub fn shapeStrides(comptime shape: []const usize) [shape.len]usize { - var st: [shape.len]usize = undefined; +pub fn shapeStrides(comptime shape: []const comptime_int) [shape.len]comptime_int { + var st: [shape.len]comptime_int = undefined; if (shape.len == 0) return st; st[shape.len - 1] = 1; if (shape.len > 1) { - var i: usize = shape.len - 1; + var i: comptime_int = shape.len - 1; while (i > 0) : (i -= 1) st[i - 1] = st[i] * shape[i]; } return st; } /// Return a copy of `shape` with the element at `axis` removed. -pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [shape.len - 1]usize { - var out: [shape.len - 1]usize = undefined; - var j: usize = 0; +pub fn shapeRemoveAxis(comptime shape: []const comptime_int, comptime axis: comptime_int) [shape.len - 1]comptime_int { + var out: [shape.len - 1]comptime_int = undefined; + var j: comptime_int = 0; for (shape, 0..) |v, i| { if (i != axis) { out[j] = v; @@ -50,8 +50,8 @@ pub fn shapeRemoveAxis(comptime shape: []const usize, comptime axis: usize) [sha } /// Concatenate two compile-time slices. -pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b.len]usize { - var out: [a.len + b.len]usize = undefined; +pub fn shapeCat(comptime a: []const comptime_int, comptime b: []const comptime_int) [a.len + b.len]comptime_int { + var out: [a.len + b.len]comptime_int = undefined; for (a, 0..) |v, i| out[i] = v; for (b, 0..) |v, i| out[a.len + i] = v; return out; @@ -60,11 +60,11 @@ pub fn shapeCat(comptime a: []const usize, comptime b: []const usize) [a.len + b /// Decode a flat row-major index into N-D coordinates. /// Called only in comptime contexts (all arguments are comptime). pub fn decodeFlatCoords( - comptime flat: usize, - comptime n: usize, - comptime strd: [n]usize, + comptime flat: comptime_int, + comptime n: comptime_int, + comptime strd: [n]comptime_int, ) [n]usize { - var coords: [n]usize = undefined; + var coords: [n]comptime_int = undefined; var tmp = flat; for (0..n) |i| { coords[i] = if (strd[i] == 0) 0 else tmp / strd[i]; @@ -116,7 +116,7 @@ fn finerScales(comptime T1: type, comptime T2: type) Scales { const s1: Scales = T1.scales; const s2: Scales = T2.scales; comptime var out = Scales.initFill(.none); - inline for (std.enums.values(Dimension)) |dim| { + for (std.enums.values(Dimension)) |dim| { const scale1 = comptime s1.get(dim); const scale2 = comptime s2.get(dim); out.set(dim, if (comptime d1.get(dim) == 0 and d2.get(dim) == 0) @@ -201,7 +201,7 @@ pub fn Tensor( comptime T: type, comptime d_opt: Dimensions.ArgOpts, comptime s_opt: Scales.ArgOpts, - comptime shape_: []const usize, + comptime shape_: []const comptime_int, ) type { comptime { if (shape_.len == 0) @compileError("Tensor shape must have at least 1 dimension (rank >= 1)."); @@ -223,10 +223,10 @@ pub fn Tensor( pub const ValueType: type = T; pub const dims: Dimensions = Dimensions.init(d_opt); pub const scales: Scales = Scales.init(s_opt); - pub const shape: []const usize = shape_; - pub const rank: usize = shape_.len; - pub const total: usize = _total; - pub const strides_arr: [shape_.len]usize = _strides; + pub const shape: []const comptime_int = shape_; + pub const rank: comptime_int = shape_.len; + pub const total: comptime_int = _total; + pub const strides_arr: [shape_.len]comptime_int = _strides; pub const ISTENSOR = true; /// Convert N-D coords (row-major) to flat index — fully comptime. @@ -359,9 +359,7 @@ pub fn Tensor( const rr_base = if (comptime RhsType == RhsNorm) rhs_q else rhs_q.to(RhsNorm); const rr: Vec = broadcastToVec(RhsNorm, rr_base); if (comptime isInt(T)) { - var result: Vec = undefined; - inline for (0..total) |i| result[i] = @divTrunc(l[i], rr[i]); - return .{ .data = result }; + return .{ .data = @divTrunc(l, rr) }; } else { return .{ .data = l / rr }; } @@ -379,19 +377,27 @@ pub fn Tensor( scales.argsOpt(), shape_, ) { - if (comptime isInt(T)) { - var result: Vec = undefined; - inline for (0..total) |i| - result[i] = std.math.powi(T, self.data[i], exp) catch std.math.maxInt(T); - return .{ .data = result }; - } else { - const abs_exp = comptime @abs(exp); - var result: Vec = @splat(1); - comptime var i = 0; - inline while (i < abs_exp) : (i += 1) result *= self.data; - if (comptime exp < 0) result = @as(Vec, @splat(1)) / result; - return .{ .data = result }; + if (comptime exp == 0) return .{ .data = @splat(1) }; + if (comptime exp == 1) return self; + + var base = self.data; + var result: Vec = @splat(1); + comptime var e = @abs(exp); + + // $O(\log n)$ Exponentiation by squaring applied to the entire vector + inline while (e > 0) { + if (e % 2 == 1) { + result = if (comptime isInt(T)) result *| base else result * base; + } + e /= 2; + if (e > 0) { + base = if (comptime isInt(T)) base *| base else base * base; + } } + if (comptime !isInt(T) and exp < 0) { + result = @as(Vec, @splat(1)) / result; + } + return .{ .data = result }; } /// Square root of every element. All dimension exponents must be even. @@ -404,15 +410,16 @@ pub fn Tensor( if (comptime !dims.isSquare()) @compileError("Cannot take sqrt of " ++ dims.str() ++ ": exponents must be even."); if (comptime @typeInfo(T) == .float) { - return .{ .data = @sqrt(self.data) }; + return .{ .data = @sqrt(self.data) }; // Float is natively vectorized! } else { - var result: Vec = undefined; + const arr: [total]T = self.data; // Add this! + var res_arr: [total]T = undefined; const UnsignedT = @Int(.unsigned, @typeInfo(T).int.bits); - inline for (0..total) |i| { - const v = self.data[i]; - result[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); + for (0..total) |i| { + const v = arr[i]; + res_arr[i] = if (v < 0) 0 else @as(T, @intCast(std.math.sqrt(@as(UnsignedT, @intCast(v))))); } - return .{ .data = result }; + return .{ .data = res_arr }; } } @@ -433,28 +440,34 @@ pub fn Tensor( if (comptime Self == ActualDest) return self; - const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); - const DestT = ActualDest.ValueType; - const DestVec = @Vector(total, DestT); - - // If ratio is 1, just handle type conversion - if (comptime ratio == 1.0) { - if (comptime T == DestT) return .{ .data = self.data }; - return .{ - .data = switch (comptime @typeInfo(DestT)) { - .float => @floatFromInt(self.data), // or @floatCast - .int => @intFromFloat(self.data), // or @intCast - else => unreachable, - }, - }; - } - + // Run validation checks FIRST before dealing with types if (comptime !dims.eql(ActualDest.dims)) @compileError("Dimension mismatch in to: " ++ dims.str() ++ " vs " ++ ActualDest.dims.str()); if (comptime Dest.total != 1 and !shapeEql(shape_, Dest.shape)) @compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar."); - if (comptime Self == ActualDest) return self; + const ratio = comptime (scales.getFactor(dims) / ActualDest.scales.getFactor(ActualDest.dims)); + const DestT = ActualDest.ValueType; + const DestVec = @Vector(total, DestT); + + // If ratio is 1, handle type conversion correctly based on BOTH source and dest types + if (comptime ratio == 1.0) { + const T_info = @typeInfo(T); + const Dest_info = @typeInfo(DestT); + + return .{ + .data = if (comptime T_info == .int and Dest_info == .int) + @as(DestVec, @intCast(self.data)) + else if (comptime T_info == .float and Dest_info == .float) + @as(DestVec, @floatCast(self.data)) + else if (comptime T_info == .int and Dest_info == .float) + @as(DestVec, @floatFromInt(self.data)) + else if (comptime T_info == .float and Dest_info == .int) + @as(DestVec, @intFromFloat(self.data)) // Or @intFromFloat(@round(self.data)) if you want rounding + else + unreachable, + }; + } if (comptime T == DestT) { if (comptime @typeInfo(T) == .float) @@ -466,33 +479,33 @@ pub fn Tensor( } else { const div_val: T = comptime @intFromFloat(@round(1.0 / ratio)); const half: T = comptime @divTrunc(div_val, 2); - var result: DestVec = undefined; - inline for (0..total) |i| { - const val = self.data[i]; - result[i] = if (val >= 0) - @divTrunc(val + half, div_val) - else - @divTrunc(val - half, div_val); + + if (comptime @typeInfo(T).int.signedness == .unsigned) { + return .{ .data = @divTrunc(self.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val))) }; + } else { + // Vectorized branchless negative handling + const is_pos = self.data >= @as(Vec, @splat(0)); + const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half))); + return .{ .data = @divTrunc(self.data + offsets, @as(Vec, @splat(div_val))) }; } - return .{ .data = result }; } } - var result: DestVec = undefined; - inline for (0..total) |i| { - const float_val: f64 = switch (comptime @typeInfo(T)) { - .float => @floatCast(self.data[i]), - .int => @floatFromInt(self.data[i]), - else => unreachable, - }; - const scaled = float_val * ratio; - result[i] = switch (comptime @typeInfo(DestT)) { - .float => @floatCast(scaled), - .int => @intFromFloat(@round(scaled)), - else => unreachable, - }; - } - return .{ .data = result }; + // Cross-type fully vectorized casting with scales + const FVec = @Vector(total, f64); + const float_vec: FVec = switch (comptime @typeInfo(T)) { + .float => @floatCast(self.data), + .int => @floatFromInt(self.data), + else => unreachable, + }; + + const scaled = float_vec * @as(FVec, @splat(ratio)); + + return switch (comptime @typeInfo(DestT)) { + .float => .{ .data = @floatCast(scaled) }, + .int => .{ .data = @intFromFloat(@round(scaled)) }, + else => unreachable, + }; } const CmpResult = if (total == 1) bool else [total]bool; @@ -592,7 +605,7 @@ pub fn Tensor( const sa = shapeRemoveAxis(shape_, axis_a); const sb = shapeRemoveAxis(OT.shape, axis_b); const rs_raw = shapeCat(&sa, &sb); - const rs: []const usize = if (rs_raw.len == 0) &.{1} else &rs_raw; + const rs: []const comptime_int = if (rs_raw.len == 0) &.{1} else &rs_raw; break :blk Tensor( T, dims.add(OT.dims).argsOpt(), @@ -606,7 +619,7 @@ pub fn Tensor( const sa = comptime shapeRemoveAxis(shape_, axis_a); const sb = comptime shapeRemoveAxis(OT.shape, axis_b); const rs_raw = comptime shapeCat(&sa, &sb); - const rs: []const usize = comptime if (rs_raw.len == 0) &.{1} else &rs_raw; + const rs: []const comptime_int = comptime if (rs_raw.len == 0) &.{1} else &rs_raw; const ResultType = Tensor( T, @@ -617,34 +630,85 @@ pub fn Tensor( const SelfNorm = Tensor(T, dims.argsOpt(), finerScales(Self, OT).argsOpt(), shape_); const OtherNorm = Tensor(T, OT.dims.argsOpt(), finerScales(Self, OT).argsOpt(), OT.shape); + const a_data = if (comptime Self == SelfNorm) self.data else self.to(SelfNorm).data; const b_data = if (comptime OT == OtherNorm) other.data else other.to(OtherNorm).data; + // FAST PATH: Dot Product + if (comptime rank == 1 and OT.rank == 1 and axis_a == 0 and axis_b == 0) { + if (comptime !isInt(T)) { + return .{ .data = @splat(@reduce(.Add, a_data * b_data)) }; + } else { + // For integers, we do a vectorized saturating multiply, + // then convert to an array to do a saturating sum + const mul_arr: [total]T = a_data *| b_data; + var acc: T = 0; + for (mul_arr) |val| acc +|= val; + return .{ .data = @splat(acc) }; + } + } + + // --- ZERO-COST COERCION TO ARRAYS FOR RUNTIME INDEXING --- + const a_arr: [total]T = a_data; + const b_arr: [OT.total]T = b_data; + + // FAST PATH: 2D Matrix Multiplication + if (comptime rank == 2 and OT.rank == 2 and axis_a == 1 and axis_b == 0) { + const rows = shape_[0]; + const cols = OT.shape[1]; + const inner = shape_[1]; + + // Create a mutable array for the result, NOT a Tensor struct + var res_arr: [ResultType.total]T = undefined; + + for (0..rows) |i| { + for (0..cols) |j| { + var acc: T = 0; + for (0..inner) |id| { + const a_flat = i * _strides[0] + id * _strides[1]; + const b_flat = id * OT.strides_arr[0] + j * OT.strides_arr[1]; + + // Use a_arr and b_arr here + if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat]; + } + // Write to the array + res_arr[i * cols + j] = acc; + } + } + // Return the initialized Tensor struct + return .{ .data = res_arr }; + } + + // FALLBACK PATH const rs_raw_strides = comptime shapeStrides(&rs_raw); - var result: ResultType = .{ .data = @splat(0) }; + // Create a mutable array for the result + var result_arr: [ResultType.total]T = undefined; - inline for (0..ResultType.total) |res_flat| { - const res_coords = comptime decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides); + for (0..ResultType.total) |res_flat| { + const res_coords = decodeFlatCoords(res_flat, rs_raw.len, rs_raw_strides); - const a_free: [sa.len]usize = comptime res_coords[0..sa.len].*; - const b_free: [sb.len]usize = comptime res_coords[sa.len..].*; + var a_free: [sa.len]usize = undefined; + for (0..sa.len) |i| a_free[i] = res_coords[i]; + var b_free: [sb.len]usize = undefined; + for (0..sb.len) |i| b_free[i] = res_coords[sa.len + i]; var acc: T = 0; - inline for (0..k) |ki| { - const a_coords = comptime insertAxis(rank, axis_a, ki, &a_free); - const b_coords = comptime insertAxis(OT.rank, axis_b, ki, &b_free); - const a_flat = comptime encodeFlatCoords(&a_coords, rank, _strides); - const b_flat = comptime encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr); + for (0..k) |ki| { + const a_coords = insertAxis(rank, axis_a, ki, &a_free); + const b_coords = insertAxis(OT.rank, axis_b, ki, &b_free); + const a_flat = encodeFlatCoords(&a_coords, rank, _strides); + const b_flat = encodeFlatCoords(&b_coords, OT.rank, OT.strides_arr); - if (comptime isInt(T)) - acc +|= a_data[a_flat] *| b_data[b_flat] - else - acc += a_data[a_flat] * b_data[b_flat]; + // Use a_arr and b_arr here + if (comptime isInt(T)) acc +|= a_arr[a_flat] *| b_arr[b_flat] else acc += a_arr[a_flat] * b_arr[b_flat]; } - result.data[res_flat] = acc; + // Write to the array + result_arr[res_flat] = acc; } - return result; + + // Return the initialized Tensor struct + return .{ .data = result_arr }; } /// 3D Cross Product. Only defined for Rank-1 tensors of length 3. @@ -724,7 +788,8 @@ pub fn Tensor( } } else { try writer.writeAll("("); - inline for (0..total) |i| { + const max_to_print = 6; + inline for (0..@min(total, max_to_print)) |i| { if (i > 0) try writer.writeAll(", "); switch (@typeInfo(T)) { .float, .comptime_float => try writer.printFloat(self.data[i], options), @@ -736,6 +801,8 @@ pub fn Tensor( }), else => unreachable, } + if (comptime i == max_to_print - 1 and total != max_to_print - 1) + try writer.writeAll(", ..."); } try writer.writeAll(")"); } diff --git a/src/benchmark.zig b/src/benchmark.zig index c50b3a9..488e494 100644 --- a/src/benchmark.zig +++ b/src/benchmark.zig @@ -23,10 +23,10 @@ pub fn main(init: std.process.Init) !void { try bench_Scalar(&stdout_writer.interface); try stdout_writer.flush(); - // try bench_vsNative(&stdout_writer.interface); - // try stdout_writer.flush(); - // try bench_crossTypeVsNative(&stdout_writer.interface); - // try stdout_writer.flush(); + try bench_vsNative(&stdout_writer.interface); + try stdout_writer.flush(); + try bench_crossTypeVsNative(&stdout_writer.interface); + try stdout_writer.flush(); try bench_Vector(&stdout_writer.interface); try stdout_writer.flush(); }