Working TensorAlloc add

This commit is contained in:
adrien 2026-05-25 01:44:09 +02:00
parent 91c5c41fc5
commit 7494595db4

View File

@ -65,14 +65,47 @@ pub fn Tensor(
return new;
}
/// Element-wise add. Dimensions must match; scales resolve to finer.
/// RHS must have the same shape as self, or total == 1 (broadcast).
pub inline fn add(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
T,
dims.argsOpt(),
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
shape,
) {
const RhsType = @TypeOf(rhs);
if (comptime !sh.isTensor(RhsType))
@compileError("rhs can only be a Tensor ");
if (comptime !dims.eql(RhsType.dims))
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
const TargetType = Tensor(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
const l: TargetType = try self.to(alloc, TargetType);
defer l.deinit(alloc);
const r: TargetType = try rhs.to(alloc, TargetType);
defer r.deinit(alloc);
const result_vec = if (comptime sh.isInt(T))
l.data.* +| r.data.*
else
l.data.* + r.data.*;
const vec_ptr = try alloc.create(@TypeOf(result_vec));
vec_ptr.* = result_vec;
return TargetType{ .data = vec_ptr };
}
/// Convert to a compatible Tensor type.
/// Dimension mismatch compile error.
/// Dest.shape must equal self.shape, or total == 1 -> splat to Dest shape (scalar pattern).
/// Scale ratio is computed fully at comptime; only a SIMD multiply at runtime.
pub inline fn to(
self: *const Self,
comptime Dest: type,
alloc: Allocator,
comptime Dest: type,
) !Dest {
if (comptime Self == Dest) return self.copy(alloc);
@ -82,79 +115,81 @@ pub fn Tensor(
if (comptime total != 1 and !sh.shapeEql(shape, Dest.shape))
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
const vec = if (comptime total == 1 and Dest.total != 1)
try Dest.splat(alloc, self.data[0])
else
try Dest.load(alloc, @ptrCast(self.data));
const ratio = comptime (scales.getFactor(dims) / Dest.scales.getFactor(Dest.dims));
const DestT = Dest.ValueType;
const DestVec = @Vector(Dest.total, DestT);
if (comptime ratio == 1.0 and T == DestT)
return self.*;
// 1. Prepare the source vector (handling scalar -> tensor broadcast)
const SrcVec = @Vector(Dest.total, T);
const src_vec: SrcVec = if (comptime total == 1 and Dest.total != 1)
@splat(self.data[0])
else
self.data.*;
// If ratio is 1, handle type conversion correctly based on BOTH source and dest types
var result_vec: DestVec = undefined;
// 2. Perform the vectorized conversion safely
if (comptime ratio == 1.0) {
if (comptime T == DestT) {
result_vec = src_vec;
} else {
const T_info = @typeInfo(T);
const Dest_info = @typeInfo(DestT);
vec.data = if (comptime T_info == .int and Dest_info == .int)
@as(DestVec, @intCast(vec.data))
result_vec = if (comptime T_info == .int and Dest_info == .int)
@as(DestVec, @intCast(src_vec))
else if (comptime T_info == .float and Dest_info == .float)
@as(DestVec, @floatCast(vec.data))
@as(DestVec, @floatCast(src_vec))
else if (comptime T_info == .int and Dest_info == .float)
@as(DestVec, @floatFromInt(vec.data))
@as(DestVec, @floatFromInt(src_vec))
else if (comptime T_info == .float and Dest_info == .int)
@as(DestVec, @intFromFloat(vec.data))
@as(DestVec, @intFromFloat(src_vec))
else
unreachable;
return vec;
}
if (comptime T == DestT) {
} else if (comptime T == DestT) {
if (comptime @typeInfo(T) == .float) {
vec.data = vec.data * @as(DestVec, @splat(@as(T, @floatCast(ratio))));
return vec;
}
result_vec = src_vec * @as(DestVec, @splat(@as(T, @floatCast(ratio))));
} else {
if (comptime ratio >= 1.0) {
const mult: T = comptime @intFromFloat(@round(ratio));
vec.data.* = vec.data.* *| @as(Vec, @splat(mult));
return vec;
result_vec = src_vec *| @as(DestVec, @splat(mult));
} else {
const div_val: T = comptime @intFromFloat(@round(1.0 / ratio));
const half: T = comptime @divTrunc(div_val, 2);
if (comptime @typeInfo(T).int.signedness == .unsigned) {
vec.data = @divTrunc(vec.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val)));
result_vec = @divTrunc(src_vec + @as(DestVec, @splat(half)), @as(DestVec, @splat(div_val)));
} else {
// Vectorized branchless negative handling
const is_pos = self.data >= @as(Vec, @splat(0));
const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half)));
vec.data = @divTrunc(vec.data + offsets, @as(Vec, @splat(div_val)));
}
return vec;
const is_pos = src_vec >= @as(DestVec, @splat(0));
const offsets = @select(T, is_pos, @as(DestVec, @splat(half)), @as(DestVec, @splat(-half)));
result_vec = @divTrunc(src_vec + offsets, @as(DestVec, @splat(div_val)));
}
}
}
} else {
// Cross-type fully vectorized casting with scales
const FVec = @Vector(total, f64);
const FVec = @Vector(Dest.total, f64);
const float_vec: FVec = switch (comptime @typeInfo(T)) {
.float => @floatCast(vec.data),
.int => @floatFromInt(vec.data),
.float => @floatCast(src_vec),
.int => @floatFromInt(src_vec),
else => unreachable,
};
const scaled = float_vec * @as(FVec, @splat(ratio));
vec.data = switch (comptime @typeInfo(DestT)) {
result_vec = switch (comptime @typeInfo(DestT)) {
.float => @floatCast(scaled),
.int => @intFromFloat(@round(scaled)),
else => unreachable,
};
return vec;
}
// 3. Allocate once and assign the computed result
const vec_ptr = try alloc.create(DestVec);
vec_ptr.* = result_vec;
return Dest{ .data = vec_ptr };
}
const CmpResult = if (total == 1) bool else [total]bool;
@ -180,7 +215,7 @@ pub fn Tensor(
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
const TargetType = Tensor(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
return .{ .l = try self.to(TargetType, alloc), .r = try rhs.to(TargetType, alloc) };
return .{ .l = try self.to(alloc, TargetType), .r = try rhs.to(alloc, TargetType) };
}
pub inline fn eq(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
@ -295,29 +330,33 @@ test "TensorAlloc | Scalar comparisons (eq, ne, gt, gte, lt, lte)" {
try std.testing.expect(try m1000.lte(alloc, km2));
}
// test "TensorAlloc | Scalar Add" {
// const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
// const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
// const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});
//
// const distance = Meter.splat(10);
// const distance2 = Meter.splat(20);
// const added = distance.add(distance2);
// try std.testing.expectEqual(30, added.data[0]);
// try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L));
//
// const distance3 = KiloMeter.splat(2);
// const added2 = distance.add(distance3);
// try std.testing.expectEqual(2010, added2.data[0]);
//
// const added3 = distance3.add(distance).to(KiloMeter);
// try std.testing.expectEqual(2, added3.data[0]);
//
// const distance4 = KiloMeter_f.splat(2);
// const added4 = distance4.add(distance).to(KiloMeter_f);
// try std.testing.expectApproxEqAbs(2.01, added4.data[0], 0.000001);
// }
//
test "TensorAlloc | Scalar Add" {
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
defer arena.deinit();
const alloc = arena.allocator();
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});
const distance = try Meter.splat(alloc, 10);
const distance2 = try Meter.splat(alloc, 20);
const added = try distance.add(alloc, distance2);
try std.testing.expectEqual(30, added.data[0]);
try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L));
const distance3 = try KiloMeter.splat(alloc, 2);
const added2 = try distance.add(alloc, distance3);
try std.testing.expectEqual(2010, added2.data[0]);
const added3 = try (try distance3.add(alloc, distance)).to(alloc, KiloMeter);
try std.testing.expectEqual(2, added3.data[0]);
const distance4 = try KiloMeter_f.splat(alloc, 2);
const added4 = try (try distance4.add(alloc, distance)).to(alloc, KiloMeter_f);
try std.testing.expectApproxEqAbs(2.01, added4.data[0], 0.000001);
}
// test "TensorAlloc | Scalar Sub" {
// const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
// const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});