diff --git a/src/TensorAlloc.zig b/src/TensorAlloc.zig
index 10c3252..67d514e 100644
--- a/src/TensorAlloc.zig
+++ b/src/TensorAlloc.zig
@@ -65,14 +65,47 @@ pub fn Tensor(
             return new;
         }
 
+        /// Element-wise add. Dimensions must match; scales resolve to finer. Both operands are first converted with `to` into the finer-scaled type; the sum saturates (`+|`) when `sh.isInt(T)` is true and uses plain `+` otherwise.
+        /// RHS must have the same shape as self, or total == 1 (broadcast). Violations are compile errors. The returned tensor's data comes from `alloc.create`; errors from `to` and from that allocation are propagated to the caller.
+        pub inline fn add(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
+            T,
+            dims.argsOpt(),
+            sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
+            shape,
+        ) {
+            const RhsType = @TypeOf(rhs);
+            if (comptime !sh.isTensor(RhsType)) // operand validation below is evaluated at comptime
+                @compileError("rhs can only be a Tensor ");
+            if (comptime !dims.eql(RhsType.dims)) // both operands must carry the same dimensions
+                @compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
+            if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape)) // a rhs with total == 1 is exempt from the shape check
+                @compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
+
+            const TargetType = Tensor(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape); // same expression as the declared return type
+            const l: TargetType = try self.to(alloc, TargetType); // temporary: self converted to TargetType
+            defer l.deinit(alloc); // temporaries are released on every exit path, success or error
+            const r: TargetType = try rhs.to(alloc, TargetType); // temporary: rhs converted to TargetType
+            defer r.deinit(alloc);
+
+            const result_vec = if (comptime sh.isInt(T)) // branch is selected at comptime from the element type
+                l.data.* +| r.data.* // saturating add
+            else
+                l.data.* + r.data.*;
+
+            const vec_ptr = try alloc.create(@TypeOf(result_vec)); // separate allocation that backs the returned tensor
+            vec_ptr.* = result_vec;
+
+            return TargetType{ .data = vec_ptr };
+        }
+
         /// Convert to a compatible Tensor type.
         /// • Dimension mismatch → compile error.
         /// • Dest.shape must equal self.shape, or total == 1 -> splat to Dest shape (scalar pattern).
         /// • Scale ratio is computed fully at comptime; only a SIMD multiply at runtime.
         pub inline fn to(
             self: *const Self,
-            comptime Dest: type,
             alloc: Allocator,
+            comptime Dest: type, // destination tensor type; now follows the allocator parameter
         ) !Dest {
             if (comptime Self == Dest)
                 return self.copy(alloc);
@@ -82,79 +115,81 @@ pub fn Tensor(
             if (comptime total != 1 and !sh.shapeEql(shape, Dest.shape))
                 @compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
 
-            const vec = if (comptime total == 1 and Dest.total != 1)
-                try Dest.splat(alloc, self.data[0])
-            else
-                try Dest.load(alloc, @ptrCast(self.data));
-
             const ratio = comptime (scales.getFactor(dims) / Dest.scales.getFactor(Dest.dims));
             const DestT = Dest.ValueType;
             const DestVec = @Vector(Dest.total, DestT);
 
-            if (comptime ratio == 1.0 and T == DestT)
-                return self.*;
+            // 1. Prepare the source vector (handling scalar -> tensor broadcast)
+            const SrcVec = @Vector(Dest.total, T); // source element type at the destination length
+            const src_vec: SrcVec = if (comptime total == 1 and Dest.total != 1)
+                @splat(self.data[0]) // broadcast the single source element to every lane
+            else
+                self.data.*; // lengths already agree: use the source vector as-is
 
-            // If ratio is 1, handle type conversion correctly based on BOTH source and dest types
+            var result_vec: DestVec = undefined; // every comptime-selected branch below assigns it before it is read
+
+            // 2. Perform the vectorized conversion safely (same scale / same element type with rescale / cross-type with rescale)
             if (comptime ratio == 1.0) {
-                const T_info = @typeInfo(T);
-                const Dest_info = @typeInfo(DestT);
-
-                vec.data = if (comptime T_info == .int and Dest_info == .int)
-                    @as(DestVec, @intCast(vec.data))
-                else if (comptime T_info == .float and Dest_info == .float)
-                    @as(DestVec, @floatCast(vec.data))
-                else if (comptime T_info == .int and Dest_info == .float)
-                    @as(DestVec, @floatFromInt(vec.data))
-                else if (comptime T_info == .float and Dest_info == .int)
-                    @as(DestVec, @intFromFloat(vec.data))
-                else
-                    unreachable;
-
-                return vec;
-            }
-
-            if (comptime T == DestT) {
-                if (comptime @typeInfo(T) == .float) {
-                    vec.data = vec.data * @as(DestVec, @splat(@as(T, @floatCast(ratio))));
-                    return vec;
-                }
-
-                if (comptime ratio >= 1.0) {
-                    const mult: T = comptime @intFromFloat(@round(ratio));
-                    vec.data.* = vec.data.* *| @as(Vec, @splat(mult));
-                    return vec;
+                if (comptime T == DestT) {
+                    result_vec = src_vec; // same scale and same element type: plain copy of the lanes
                 } else {
-                    const div_val: T = comptime @intFromFloat(@round(1.0 / ratio));
-                    const half: T = comptime @divTrunc(div_val, 2);
+                    const T_info = @typeInfo(T); // same scale, different element type: pick the cast builtin by type class
+                    const Dest_info = @typeInfo(DestT);
 
-                    if (comptime @typeInfo(T).int.signedness == .unsigned) {
-                        vec.data = @divTrunc(vec.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val)));
-                    } else {
-                        // Vectorized branchless negative handling
-                        const is_pos = self.data >= @as(Vec, @splat(0));
-                        const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half)));
-                        vec.data = @divTrunc(vec.data + offsets, @as(Vec, @splat(div_val)));
-                    }
-                    return vec;
+                    result_vec = if (comptime T_info == .int and Dest_info == .int)
+                        @as(DestVec, @intCast(src_vec))
+                    else if (comptime T_info == .float and Dest_info == .float)
+                        @as(DestVec, @floatCast(src_vec))
+                    else if (comptime T_info == .int and Dest_info == .float)
+                        @as(DestVec, @floatFromInt(src_vec))
+                    else if (comptime T_info == .float and Dest_info == .int)
+                        @as(DestVec, @intFromFloat(src_vec)) // NOTE(review): no @round here, unlike the scaled cross-type path below — confirm this is intended
+                    else
+                        unreachable;
                 }
+            } else if (comptime T == DestT) {
+                if (comptime @typeInfo(T) == .float) {
+                    result_vec = src_vec * @as(DestVec, @splat(@as(T, @floatCast(ratio)))); // float: multiply every lane by the comptime ratio
+                } else {
+                    if (comptime ratio >= 1.0) {
+                        const mult: T = comptime @intFromFloat(@round(ratio)); // integer scale-up factor, rounded at comptime
+                        result_vec = src_vec *| @as(DestVec, @splat(mult)); // saturating multiply
+                    } else {
+                        const div_val: T = comptime @intFromFloat(@round(1.0 / ratio)); // integer scale-down divisor, rounded at comptime
+                        const half: T = comptime @divTrunc(div_val, 2); // bias added before the truncating division
+
+                        if (comptime @typeInfo(T).int.signedness == .unsigned) {
+                            result_vec = @divTrunc(src_vec + @as(DestVec, @splat(half)), @as(DestVec, @splat(div_val)));
+                        } else {
+                            // Vectorized branchless negative handling: lanes >= 0 get +half, negative lanes get -half
+                            const is_pos = src_vec >= @as(DestVec, @splat(0));
+                            const offsets = @select(T, is_pos, @as(DestVec, @splat(half)), @as(DestVec, @splat(-half)));
+                            result_vec = @divTrunc(src_vec + offsets, @as(DestVec, @splat(div_val)));
+                        }
+                    }
+                }
+            } else {
+                // Cross-type fully vectorized casting with scales: widen to f64, scale, then convert to DestT
+                const FVec = @Vector(Dest.total, f64); // the intermediate math is carried out in f64
+                const float_vec: FVec = switch (comptime @typeInfo(T)) {
+                    .float => @floatCast(src_vec),
+                    .int => @floatFromInt(src_vec),
+                    else => unreachable,
+                };
+
+                const scaled = float_vec * @as(FVec, @splat(ratio));
+
+                result_vec = switch (comptime @typeInfo(DestT)) {
+                    .float => @floatCast(scaled),
+                    .int => @intFromFloat(@round(scaled)), // integer destinations are rounded before conversion
+                    else => unreachable,
+                };
             }
 
-            // Cross-type fully vectorized casting with scales
-            const FVec = @Vector(total, f64);
-            const float_vec: FVec = switch (comptime @typeInfo(T)) {
-                .float => @floatCast(vec.data),
-                .int => @floatFromInt(vec.data),
-                else => unreachable,
-            };
-
-            const scaled = float_vec * @as(FVec, @splat(ratio));
-
-            vec.data = switch (comptime @typeInfo(DestT)) {
-                .float => @floatCast(scaled),
-                .int => @intFromFloat(@round(scaled)),
-                else => unreachable,
-            };
-            return vec;
+            // 3. Allocate once and assign the computed result
+            const vec_ptr = try alloc.create(DestVec); // the only allocation made by the conversion paths above
+            vec_ptr.* = result_vec;
+            return Dest{ .data = vec_ptr };
         }
 
         const CmpResult = if (total == 1) bool else [total]bool;
@@ -180,7 +215,7 @@ pub fn Tensor(
                 @compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
 
             const TargetType = Tensor(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
-            return .{ .l = try self.to(TargetType, alloc), .r = try rhs.to(TargetType, alloc) };
+            return .{ .l = try self.to(alloc, TargetType), .r = try rhs.to(alloc, TargetType) }; // NOTE(review): if the second `to` fails, the result of the first is not released on this line — confirm callers tolerate that
         }
 
         pub inline fn eq(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
@@ -295,29 +330,33 @@ test "TensorAlloc | Scalar comparisons (eq, ne, gt, gte, lt, lte)" {
     try std.testing.expect(try m1000.lte(alloc, km2));
 }
 
-// test "TensorAlloc | Scalar Add" {
-//     const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
-//     const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
-//     const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});
-//
-//     const distance = Meter.splat(10);
-//     const distance2 = Meter.splat(20);
-//     const added = distance.add(distance2);
-//     try std.testing.expectEqual(30, added.data[0]);
-//     try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L));
-//
-//     const distance3 = KiloMeter.splat(2);
-//     const added2 = distance.add(distance3);
-//     try std.testing.expectEqual(2010, added2.data[0]);
-//
-//     const added3 = distance3.add(distance).to(KiloMeter);
-//     try std.testing.expectEqual(2, added3.data[0]);
-//
-//     const distance4 = KiloMeter_f.splat(2);
-//     const added4 = distance4.add(distance).to(KiloMeter_f);
-//     try std.testing.expectApproxEqAbs(2.01, added4.data[0], 0.000001);
-// }
-//
+test "TensorAlloc | Scalar Add" {
+    var arena = std.heap.ArenaAllocator.init(std.testing.allocator); // tensors below are never freed one by one; the arena releases them together
+    defer arena.deinit();
+    const alloc = arena.allocator();
+
+    const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1}); // integer tensor, no scale prefix
+    const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1}); // integer tensor, .k scale on L
+    const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1}); // float tensor, .k scale on L
+
+    const distance = try Meter.splat(alloc, 10);
+    const distance2 = try Meter.splat(alloc, 20);
+    const added = try distance.add(alloc, distance2); // same type on both sides
+    try std.testing.expectEqual(30, added.data[0]);
+    try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L)); // the L dimension is carried through unchanged
+
+    const distance3 = try KiloMeter.splat(alloc, 2);
+    const added2 = try distance.add(alloc, distance3); // mixed scales: the sum is expressed in the unprefixed scale
+    try std.testing.expectEqual(2010, added2.data[0]);
+
+    const added3 = try (try distance3.add(alloc, distance)).to(alloc, KiloMeter); // the intermediate sum stays owned by the arena
+    try std.testing.expectEqual(2, added3.data[0]); // converting 2010 back to the .k scale yields 2 in i128
+
+    const distance4 = try KiloMeter_f.splat(alloc, 2);
+    const added4 = try (try distance4.add(alloc, distance)).to(alloc, KiloMeter_f); // float lhs with an integer rhs
+    try std.testing.expectApproxEqAbs(2.01, added4.data[0], 0.000001);
+}
+
 // test "TensorAlloc | Scalar Sub" {
 //     const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
 //     const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});