Working comparaison for TensorAlloc

This commit is contained in:
adrien 2026-05-24 21:23:20 +02:00
parent ba671ee486
commit 91c5c41fc5

View File

@ -74,7 +74,7 @@ pub fn Tensor(
comptime Dest: type, comptime Dest: type,
alloc: Allocator, alloc: Allocator,
) !Dest { ) !Dest {
if (comptime Self == Dest) return self.*; if (comptime Self == Dest) return self.copy(alloc);
// Run validation checks FIRST before dealing with types // Run validation checks FIRST before dealing with types
if (comptime !dims.eql(Dest.dims)) if (comptime !dims.eql(Dest.dims))
@ -191,52 +191,58 @@ pub fn Tensor(
return cmpResult(p.l.data.* == p.r.data.*); return cmpResult(p.l.data.* == p.r.data.*);
} }
pub inline fn ne(self: *const Self, rhs: anytype) CmpResult { pub inline fn ne(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims)) if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in ne."); @compileError("Dimension mismatch in ne.");
const p = resolveScalePair(self, rhs); const p = try resolveScalePair(self, alloc, rhs);
return cmpResult(p.l != p.r); defer p.deinit(alloc);
return cmpResult(p.l.data.* != p.r.data.*);
} }
pub inline fn gt(self: *const Self, rhs: anytype) CmpResult { pub inline fn gt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims)) if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in gt."); @compileError("Dimension mismatch in gt.");
const p = resolveScalePair(self, rhs); const p = try resolveScalePair(self, alloc, rhs);
return cmpResult(p.l > p.r); defer p.deinit(alloc);
return cmpResult(p.l.data.* > p.r.data.*);
} }
pub inline fn gte(self: *const Self, rhs: anytype) CmpResult { pub inline fn gte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims)) if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in gte."); @compileError("Dimension mismatch in gte.");
const p = resolveScalePair(self, rhs); const p = try resolveScalePair(self, alloc, rhs);
return cmpResult(p.l >= p.r); defer p.deinit(alloc);
return cmpResult(p.l.data.* >= p.r.data.*);
} }
pub inline fn lt(self: *const Self, rhs: anytype) CmpResult { pub inline fn lt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims)) if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in lt."); @compileError("Dimension mismatch in lt.");
const p = resolveScalePair(self, rhs); const p = try resolveScalePair(self, alloc, rhs);
return cmpResult(p.l < p.r); defer p.deinit(alloc);
return cmpResult(p.l.data.* < p.r.data.*);
} }
pub inline fn lte(self: *const Self, rhs: anytype) CmpResult { pub inline fn lte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
if (comptime !dims.eql(@TypeOf(rhs).dims)) if (comptime !dims.eql(@TypeOf(rhs).dims))
@compileError("Dimension mismatch in lte."); @compileError("Dimension mismatch in lte.");
const p = resolveScalePair(self, rhs); const p = try resolveScalePair(self, alloc, rhs);
return cmpResult(p.l <= p.r); defer p.deinit(alloc);
return cmpResult(p.l.data.* <= p.r.data.*);
} }
/// True iff every element is equal after scale resolution. /// True iff every element is equal after scale resolution.
pub inline fn eqAll(self: *const Self, other: anytype) bool { pub inline fn eqAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
if (comptime !dims.eql(@TypeOf(other).dims)) if (comptime !dims.eql(@TypeOf(other).dims))
@compileError("Dimension mismatch in eqAll."); @compileError("Dimension mismatch in eqAll.");
const p = resolveScalePair(self, other); const p = try resolveScalePair(self, alloc, other);
return @reduce(.And, p.l == p.r); defer p.deinit(alloc);
return @reduce(.And, p.l.data.* == p.r.data.*);
} }
/// True iff any element differs after scale resolution. /// True iff any element differs after scale resolution.
pub inline fn neAll(self: *const Self, other: anytype) bool { pub inline fn neAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
return !self.eqAll(other); return !(try self.eqAll(alloc, other));
} }
}; };
} }
@ -272,21 +278,21 @@ test "TensorAlloc | Scalar comparisons (eq, ne, gt, gte, lt, lte)" {
const m1000 = try Meter.splat(alloc, 1000); const m1000 = try Meter.splat(alloc, 1000);
const km1 = try KiloMeter.splat(alloc, 1); const km1 = try KiloMeter.splat(alloc, 1);
// const km2 = try KiloMeter.splat(alloc, 2); const km2 = try KiloMeter.splat(alloc, 2);
try std.testing.expect(try m1000.eq(alloc, km1)); try std.testing.expect(try m1000.eq(alloc, km1));
try std.testing.expect(try km1.eq(alloc, m1000)); try std.testing.expect(try km1.eq(alloc, m1000));
// try std.testing.expect(km2.ne(m1000)); try std.testing.expect(try km2.ne(alloc, m1000));
//
// try std.testing.expect(km2.gt(m1000)); try std.testing.expect(try km2.gt(alloc, m1000));
// try std.testing.expect(km2.gt(km1)); try std.testing.expect(try km2.gt(alloc, km1));
// try std.testing.expect(km1.gte(m1000)); try std.testing.expect(try km1.gte(alloc, m1000));
// try std.testing.expect(km2.gte(m1000)); try std.testing.expect(try km2.gte(alloc, m1000));
//
// try std.testing.expect(m1000.lt(km2)); try std.testing.expect(try m1000.lt(alloc, km2));
// try std.testing.expect(km1.lt(km2)); try std.testing.expect(try km1.lt(alloc, km2));
// try std.testing.expect(km1.lte(m1000)); try std.testing.expect(try km1.lte(alloc, m1000));
// try std.testing.expect(m1000.lte(km2)); try std.testing.expect(try m1000.lte(alloc, km2));
} }
// test "TensorAlloc | Scalar Add" { // test "TensorAlloc | Scalar Add" {