Working comparaison for TensorAlloc

This commit is contained in:
adrien 2026-05-24 21:23:20 +02:00
parent ba671ee486
commit 91c5c41fc5

View File

@ -74,7 +74,7 @@ pub fn Tensor(
comptime Dest: type,
alloc: Allocator,
) !Dest {
if (comptime Self == Dest) return self.*;
if (comptime Self == Dest) return self.copy(alloc);
// Run validation checks FIRST before dealing with types
if (comptime !dims.eql(Dest.dims))
@ -191,52 +191,58 @@ pub fn Tensor(
return cmpResult(p.l.data.* == p.r.data.*);
}
pub inline fn ne(self: *const Self, rhs: anytype) CmpResult {
pub inline fn ne(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in ne.");
const p = resolveScalePair(self, rhs);
return cmpResult(p.l != p.r);
const p = try resolveScalePair(self, alloc, rhs);
defer p.deinit(alloc);
return cmpResult(p.l.data.* != p.r.data.*);
}
pub inline fn gt(self: *const Self, rhs: anytype) CmpResult {
pub inline fn gt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in gt.");
const p = resolveScalePair(self, rhs);
return cmpResult(p.l > p.r);
const p = try resolveScalePair(self, alloc, rhs);
defer p.deinit(alloc);
return cmpResult(p.l.data.* > p.r.data.*);
}
pub inline fn gte(self: *const Self, rhs: anytype) CmpResult {
pub inline fn gte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in gte.");
const p = resolveScalePair(self, rhs);
return cmpResult(p.l >= p.r);
const p = try resolveScalePair(self, alloc, rhs);
defer p.deinit(alloc);
return cmpResult(p.l.data.* >= p.r.data.*);
}
pub inline fn lt(self: *const Self, rhs: anytype) CmpResult {
pub inline fn lt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in lt.");
const p = resolveScalePair(self, rhs);
return cmpResult(p.l < p.r);
const p = try resolveScalePair(self, alloc, rhs);
defer p.deinit(alloc);
return cmpResult(p.l.data.* < p.r.data.*);
}
pub inline fn lte(self: *const Self, rhs: anytype) CmpResult {
pub inline fn lte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in lte.");
const p = resolveScalePair(self, rhs);
return cmpResult(p.l <= p.r);
const p = try resolveScalePair(self, alloc, rhs);
defer p.deinit(alloc);
return cmpResult(p.l.data.* <= p.r.data.*);
}
/// True iff every element is equal after scale resolution.
pub inline fn eqAll(self: *const Self, other: anytype) bool {
pub inline fn eqAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
if (comptime !dims.eql(@TypeOf(other).dims))
@compileError("Dimension mismatch in eqAll.");
const p = resolveScalePair(self, other);
return @reduce(.And, p.l == p.r);
const p = try resolveScalePair(self, alloc, other);
defer p.deinit(alloc);
return @reduce(.And, p.l.data.* == p.r.data.*);
}
/// True iff any element differs after scale resolution.
pub inline fn neAll(self: *const Self, other: anytype) bool {
return !self.eqAll(other);
pub inline fn neAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
return !(try self.eqAll(alloc, other));
}
};
}
@ -272,21 +278,21 @@ test "TensorAlloc | Scalar comparisons (eq, ne, gt, gte, lt, lte)" {
const m1000 = try Meter.splat(alloc, 1000);
const km1 = try KiloMeter.splat(alloc, 1);
// const km2 = try KiloMeter.splat(alloc, 2);
const km2 = try KiloMeter.splat(alloc, 2);
try std.testing.expect(try m1000.eq(alloc, km1));
try std.testing.expect(try km1.eq(alloc, m1000));
// try std.testing.expect(km2.ne(m1000));
//
// try std.testing.expect(km2.gt(m1000));
// try std.testing.expect(km2.gt(km1));
// try std.testing.expect(km1.gte(m1000));
// try std.testing.expect(km2.gte(m1000));
//
// try std.testing.expect(m1000.lt(km2));
// try std.testing.expect(km1.lt(km2));
// try std.testing.expect(km1.lte(m1000));
// try std.testing.expect(m1000.lte(km2));
try std.testing.expect(try km2.ne(alloc, m1000));
try std.testing.expect(try km2.gt(alloc, m1000));
try std.testing.expect(try km2.gt(alloc, km1));
try std.testing.expect(try km1.gte(alloc, m1000));
try std.testing.expect(try km2.gte(alloc, m1000));
try std.testing.expect(try m1000.lt(alloc, km2));
try std.testing.expect(try km1.lt(alloc, km2));
try std.testing.expect(try km1.lte(alloc, m1000));
try std.testing.expect(try m1000.lte(alloc, km2));
}
// test "TensorAlloc | Scalar Add" {