Working TensorAlloc add
This commit is contained in:
parent
91c5c41fc5
commit
7494595db4
@ -65,14 +65,47 @@ pub fn Tensor(
|
|||||||
return new;
|
return new;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Element-wise add. Dimensions must match; scales resolve to finer.
|
||||||
|
/// RHS must have the same shape as self, or total == 1 (broadcast).
|
||||||
|
pub inline fn add(self: *const Self, alloc: Allocator, rhs: anytype) !Tensor(
|
||||||
|
T,
|
||||||
|
dims.argsOpt(),
|
||||||
|
sh.finerScales(Self, @TypeOf(rhs)).argsOpt(),
|
||||||
|
shape,
|
||||||
|
) {
|
||||||
|
const RhsType = @TypeOf(rhs);
|
||||||
|
if (comptime !sh.isTensor(RhsType))
|
||||||
|
@compileError("rhs can only be a Tensor ");
|
||||||
|
if (comptime !dims.eql(RhsType.dims))
|
||||||
|
@compileError("Dimension mismatch in add: " ++ dims.str() ++ " vs " ++ RhsType.dims.str());
|
||||||
|
if (comptime RhsType.total != 1 and !sh.shapeEql(shape, RhsType.shape))
|
||||||
|
@compileError("Shape mismatch in add: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
|
|
||||||
|
const TargetType = Tensor(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||||
|
const l: TargetType = try self.to(alloc, TargetType);
|
||||||
|
defer l.deinit(alloc);
|
||||||
|
const r: TargetType = try rhs.to(alloc, TargetType);
|
||||||
|
defer r.deinit(alloc);
|
||||||
|
|
||||||
|
const result_vec = if (comptime sh.isInt(T))
|
||||||
|
l.data.* +| r.data.*
|
||||||
|
else
|
||||||
|
l.data.* + r.data.*;
|
||||||
|
|
||||||
|
const vec_ptr = try alloc.create(@TypeOf(result_vec));
|
||||||
|
vec_ptr.* = result_vec;
|
||||||
|
|
||||||
|
return TargetType{ .data = vec_ptr };
|
||||||
|
}
|
||||||
|
|
||||||
/// Convert to a compatible Tensor type.
|
/// Convert to a compatible Tensor type.
|
||||||
/// • Dimension mismatch → compile error.
|
/// • Dimension mismatch → compile error.
|
||||||
/// • Dest.shape must equal self.shape, or total == 1 -> splat to Dest shape (scalar pattern).
|
/// • Dest.shape must equal self.shape, or total == 1 -> splat to Dest shape (scalar pattern).
|
||||||
/// • Scale ratio is computed fully at comptime; only a SIMD multiply at runtime.
|
/// • Scale ratio is computed fully at comptime; only a SIMD multiply at runtime.
|
||||||
pub inline fn to(
|
pub inline fn to(
|
||||||
self: *const Self,
|
self: *const Self,
|
||||||
comptime Dest: type,
|
|
||||||
alloc: Allocator,
|
alloc: Allocator,
|
||||||
|
comptime Dest: type,
|
||||||
) !Dest {
|
) !Dest {
|
||||||
if (comptime Self == Dest) return self.copy(alloc);
|
if (comptime Self == Dest) return self.copy(alloc);
|
||||||
|
|
||||||
@ -82,79 +115,81 @@ pub fn Tensor(
|
|||||||
if (comptime total != 1 and !sh.shapeEql(shape, Dest.shape))
|
if (comptime total != 1 and !sh.shapeEql(shape, Dest.shape))
|
||||||
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
|
@compileError("Shape mismatch in to: destination type must have the identical shape, or be a scalar.");
|
||||||
|
|
||||||
const vec = if (comptime total == 1 and Dest.total != 1)
|
|
||||||
try Dest.splat(alloc, self.data[0])
|
|
||||||
else
|
|
||||||
try Dest.load(alloc, @ptrCast(self.data));
|
|
||||||
|
|
||||||
const ratio = comptime (scales.getFactor(dims) / Dest.scales.getFactor(Dest.dims));
|
const ratio = comptime (scales.getFactor(dims) / Dest.scales.getFactor(Dest.dims));
|
||||||
const DestT = Dest.ValueType;
|
const DestT = Dest.ValueType;
|
||||||
const DestVec = @Vector(Dest.total, DestT);
|
const DestVec = @Vector(Dest.total, DestT);
|
||||||
|
|
||||||
if (comptime ratio == 1.0 and T == DestT)
|
// 1. Prepare the source vector (handling scalar -> tensor broadcast)
|
||||||
return self.*;
|
const SrcVec = @Vector(Dest.total, T);
|
||||||
|
const src_vec: SrcVec = if (comptime total == 1 and Dest.total != 1)
|
||||||
|
@splat(self.data[0])
|
||||||
|
else
|
||||||
|
self.data.*;
|
||||||
|
|
||||||
// If ratio is 1, handle type conversion correctly based on BOTH source and dest types
|
var result_vec: DestVec = undefined;
|
||||||
|
|
||||||
|
// 2. Perform the vectorized conversion safely
|
||||||
if (comptime ratio == 1.0) {
|
if (comptime ratio == 1.0) {
|
||||||
const T_info = @typeInfo(T);
|
if (comptime T == DestT) {
|
||||||
const Dest_info = @typeInfo(DestT);
|
result_vec = src_vec;
|
||||||
|
|
||||||
vec.data = if (comptime T_info == .int and Dest_info == .int)
|
|
||||||
@as(DestVec, @intCast(vec.data))
|
|
||||||
else if (comptime T_info == .float and Dest_info == .float)
|
|
||||||
@as(DestVec, @floatCast(vec.data))
|
|
||||||
else if (comptime T_info == .int and Dest_info == .float)
|
|
||||||
@as(DestVec, @floatFromInt(vec.data))
|
|
||||||
else if (comptime T_info == .float and Dest_info == .int)
|
|
||||||
@as(DestVec, @intFromFloat(vec.data))
|
|
||||||
else
|
|
||||||
unreachable;
|
|
||||||
|
|
||||||
return vec;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (comptime T == DestT) {
|
|
||||||
if (comptime @typeInfo(T) == .float) {
|
|
||||||
vec.data = vec.data * @as(DestVec, @splat(@as(T, @floatCast(ratio))));
|
|
||||||
return vec;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (comptime ratio >= 1.0) {
|
|
||||||
const mult: T = comptime @intFromFloat(@round(ratio));
|
|
||||||
vec.data.* = vec.data.* *| @as(Vec, @splat(mult));
|
|
||||||
return vec;
|
|
||||||
} else {
|
} else {
|
||||||
const div_val: T = comptime @intFromFloat(@round(1.0 / ratio));
|
const T_info = @typeInfo(T);
|
||||||
const half: T = comptime @divTrunc(div_val, 2);
|
const Dest_info = @typeInfo(DestT);
|
||||||
|
|
||||||
if (comptime @typeInfo(T).int.signedness == .unsigned) {
|
result_vec = if (comptime T_info == .int and Dest_info == .int)
|
||||||
vec.data = @divTrunc(vec.data + @as(Vec, @splat(half)), @as(Vec, @splat(div_val)));
|
@as(DestVec, @intCast(src_vec))
|
||||||
} else {
|
else if (comptime T_info == .float and Dest_info == .float)
|
||||||
// Vectorized branchless negative handling
|
@as(DestVec, @floatCast(src_vec))
|
||||||
const is_pos = self.data >= @as(Vec, @splat(0));
|
else if (comptime T_info == .int and Dest_info == .float)
|
||||||
const offsets = @select(T, is_pos, @as(Vec, @splat(half)), @as(Vec, @splat(-half)));
|
@as(DestVec, @floatFromInt(src_vec))
|
||||||
vec.data = @divTrunc(vec.data + offsets, @as(Vec, @splat(div_val)));
|
else if (comptime T_info == .float and Dest_info == .int)
|
||||||
}
|
@as(DestVec, @intFromFloat(src_vec))
|
||||||
return vec;
|
else
|
||||||
|
unreachable;
|
||||||
}
|
}
|
||||||
|
} else if (comptime T == DestT) {
|
||||||
|
if (comptime @typeInfo(T) == .float) {
|
||||||
|
result_vec = src_vec * @as(DestVec, @splat(@as(T, @floatCast(ratio))));
|
||||||
|
} else {
|
||||||
|
if (comptime ratio >= 1.0) {
|
||||||
|
const mult: T = comptime @intFromFloat(@round(ratio));
|
||||||
|
result_vec = src_vec *| @as(DestVec, @splat(mult));
|
||||||
|
} else {
|
||||||
|
const div_val: T = comptime @intFromFloat(@round(1.0 / ratio));
|
||||||
|
const half: T = comptime @divTrunc(div_val, 2);
|
||||||
|
|
||||||
|
if (comptime @typeInfo(T).int.signedness == .unsigned) {
|
||||||
|
result_vec = @divTrunc(src_vec + @as(DestVec, @splat(half)), @as(DestVec, @splat(div_val)));
|
||||||
|
} else {
|
||||||
|
// Vectorized branchless negative handling
|
||||||
|
const is_pos = src_vec >= @as(DestVec, @splat(0));
|
||||||
|
const offsets = @select(T, is_pos, @as(DestVec, @splat(half)), @as(DestVec, @splat(-half)));
|
||||||
|
result_vec = @divTrunc(src_vec + offsets, @as(DestVec, @splat(div_val)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Cross-type fully vectorized casting with scales
|
||||||
|
const FVec = @Vector(Dest.total, f64);
|
||||||
|
const float_vec: FVec = switch (comptime @typeInfo(T)) {
|
||||||
|
.float => @floatCast(src_vec),
|
||||||
|
.int => @floatFromInt(src_vec),
|
||||||
|
else => unreachable,
|
||||||
|
};
|
||||||
|
|
||||||
|
const scaled = float_vec * @as(FVec, @splat(ratio));
|
||||||
|
|
||||||
|
result_vec = switch (comptime @typeInfo(DestT)) {
|
||||||
|
.float => @floatCast(scaled),
|
||||||
|
.int => @intFromFloat(@round(scaled)),
|
||||||
|
else => unreachable,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cross-type fully vectorized casting with scales
|
// 3. Allocate once and assign the computed result
|
||||||
const FVec = @Vector(total, f64);
|
const vec_ptr = try alloc.create(DestVec);
|
||||||
const float_vec: FVec = switch (comptime @typeInfo(T)) {
|
vec_ptr.* = result_vec;
|
||||||
.float => @floatCast(vec.data),
|
return Dest{ .data = vec_ptr };
|
||||||
.int => @floatFromInt(vec.data),
|
|
||||||
else => unreachable,
|
|
||||||
};
|
|
||||||
|
|
||||||
const scaled = float_vec * @as(FVec, @splat(ratio));
|
|
||||||
|
|
||||||
vec.data = switch (comptime @typeInfo(DestT)) {
|
|
||||||
.float => @floatCast(scaled),
|
|
||||||
.int => @intFromFloat(@round(scaled)),
|
|
||||||
else => unreachable,
|
|
||||||
};
|
|
||||||
return vec;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const CmpResult = if (total == 1) bool else [total]bool;
|
const CmpResult = if (total == 1) bool else [total]bool;
|
||||||
@ -180,7 +215,7 @@ pub fn Tensor(
|
|||||||
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
@compileError("Shape mismatch in comparison: element-wise operations require identical shapes, or a scalar RHS.");
|
||||||
|
|
||||||
const TargetType = Tensor(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
const TargetType = Tensor(T, dims.argsOpt(), sh.finerScales(Self, RhsType).argsOpt(), shape);
|
||||||
return .{ .l = try self.to(TargetType, alloc), .r = try rhs.to(TargetType, alloc) };
|
return .{ .l = try self.to(alloc, TargetType), .r = try rhs.to(alloc, TargetType) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn eq(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
pub inline fn eq(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||||
@ -295,29 +330,33 @@ test "TensorAlloc | Scalar comparisons (eq, ne, gt, gte, lt, lte)" {
|
|||||||
try std.testing.expect(try m1000.lte(alloc, km2));
|
try std.testing.expect(try m1000.lte(alloc, km2));
|
||||||
}
|
}
|
||||||
|
|
||||||
// test "TensorAlloc | Scalar Add" {
|
test "TensorAlloc | Scalar Add" {
|
||||||
// const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||||||
// const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
defer arena.deinit();
|
||||||
// const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
const alloc = arena.allocator();
|
||||||
//
|
|
||||||
// const distance = Meter.splat(10);
|
const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
||||||
// const distance2 = Meter.splat(20);
|
const KiloMeter = Tensor(i128, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
// const added = distance.add(distance2);
|
const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
// try std.testing.expectEqual(30, added.data[0]);
|
|
||||||
// try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L));
|
const distance = try Meter.splat(alloc, 10);
|
||||||
//
|
const distance2 = try Meter.splat(alloc, 20);
|
||||||
// const distance3 = KiloMeter.splat(2);
|
const added = try distance.add(alloc, distance2);
|
||||||
// const added2 = distance.add(distance3);
|
try std.testing.expectEqual(30, added.data[0]);
|
||||||
// try std.testing.expectEqual(2010, added2.data[0]);
|
try std.testing.expectEqual(1, @TypeOf(added).dims.get(.L));
|
||||||
//
|
|
||||||
// const added3 = distance3.add(distance).to(KiloMeter);
|
const distance3 = try KiloMeter.splat(alloc, 2);
|
||||||
// try std.testing.expectEqual(2, added3.data[0]);
|
const added2 = try distance.add(alloc, distance3);
|
||||||
//
|
try std.testing.expectEqual(2010, added2.data[0]);
|
||||||
// const distance4 = KiloMeter_f.splat(2);
|
|
||||||
// const added4 = distance4.add(distance).to(KiloMeter_f);
|
const added3 = try (try distance3.add(alloc, distance)).to(alloc, KiloMeter);
|
||||||
// try std.testing.expectApproxEqAbs(2.01, added4.data[0], 0.000001);
|
try std.testing.expectEqual(2, added3.data[0]);
|
||||||
// }
|
|
||||||
//
|
const distance4 = try KiloMeter_f.splat(alloc, 2);
|
||||||
|
const added4 = try (try distance4.add(alloc, distance)).to(alloc, KiloMeter_f);
|
||||||
|
try std.testing.expectApproxEqAbs(2.01, added4.data[0], 0.000001);
|
||||||
|
}
|
||||||
|
|
||||||
// test "TensorAlloc | Scalar Sub" {
|
// test "TensorAlloc | Scalar Sub" {
|
||||||
// const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
// const Meter = Tensor(i128, .{ .L = 1 }, .{}, &.{1});
|
||||||
// const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
// const KiloMeter_f = Tensor(f64, .{ .L = 1 }, .{ .L = .k }, &.{1});
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user