From 91c5c41fc5684ed766d2c5d6a81f033cc9d03f8d Mon Sep 17 00:00:00 2001 From: adrien Date: Sun, 24 May 2026 21:23:20 +0200 Subject: [PATCH] Working comparaison for TensorAlloc --- src/TensorAlloc.zig | 72 ++++++++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 33 deletions(-) diff --git a/src/TensorAlloc.zig b/src/TensorAlloc.zig index fed9f79..10c3252 100644 --- a/src/TensorAlloc.zig +++ b/src/TensorAlloc.zig @@ -74,7 +74,7 @@ pub fn Tensor( comptime Dest: type, alloc: Allocator, ) !Dest { - if (comptime Self == Dest) return self.*; + if (comptime Self == Dest) return self.copy(alloc); // Run validation checks FIRST before dealing with types if (comptime !dims.eql(Dest.dims)) @@ -191,52 +191,58 @@ pub fn Tensor( return cmpResult(p.l.data.* == p.r.data.*); } - pub inline fn ne(self: *const Self, rhs: anytype) CmpResult { + pub inline fn ne(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult { if (comptime !dims.eql(@TypeOf(rhs).dims)) @compileError("Dimension mismatch in ne."); - const p = resolveScalePair(self, rhs); - return cmpResult(p.l != p.r); + const p = try resolveScalePair(self, alloc, rhs); + defer p.deinit(alloc); + return cmpResult(p.l.data.* != p.r.data.*); } - pub inline fn gt(self: *const Self, rhs: anytype) CmpResult { + pub inline fn gt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult { if (comptime !dims.eql(@TypeOf(rhs).dims)) @compileError("Dimension mismatch in gt."); - const p = resolveScalePair(self, rhs); - return cmpResult(p.l > p.r); + const p = try resolveScalePair(self, alloc, rhs); + defer p.deinit(alloc); + return cmpResult(p.l.data.* > p.r.data.*); } - pub inline fn gte(self: *const Self, rhs: anytype) CmpResult { + pub inline fn gte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult { if (comptime !dims.eql(@TypeOf(rhs).dims)) @compileError("Dimension mismatch in gte."); - const p = resolveScalePair(self, rhs); - return cmpResult(p.l >= p.r); + const p = try resolveScalePair(self, alloc, rhs); + defer p.deinit(alloc); + return cmpResult(p.l.data.* >= p.r.data.*); } - pub inline fn lt(self: *const Self, rhs: anytype) CmpResult { + pub inline fn lt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult { if (comptime !dims.eql(@TypeOf(rhs).dims)) @compileError("Dimension mismatch in lt."); - const p = resolveScalePair(self, rhs); - return cmpResult(p.l < p.r); + const p = try resolveScalePair(self, alloc, rhs); + defer p.deinit(alloc); + return cmpResult(p.l.data.* < p.r.data.*); } - pub inline fn lte(self: *const Self, rhs: anytype) CmpResult { + pub inline fn lte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult { if (comptime !dims.eql(@TypeOf(rhs).dims)) @compileError("Dimension mismatch in lte."); - const p = resolveScalePair(self, rhs); - return cmpResult(p.l <= p.r); + const p = try resolveScalePair(self, alloc, rhs); + defer p.deinit(alloc); + return cmpResult(p.l.data.* <= p.r.data.*); } /// True iff every element is equal after scale resolution. - pub inline fn eqAll(self: *const Self, other: anytype) bool { + pub inline fn eqAll(self: *const Self, alloc: Allocator, other: anytype) !bool { if (comptime !dims.eql(@TypeOf(other).dims)) @compileError("Dimension mismatch in eqAll."); - const p = resolveScalePair(self, other); - return @reduce(.And, p.l == p.r); + const p = try resolveScalePair(self, alloc, other); + defer p.deinit(alloc); + return @reduce(.And, p.l.data.* == p.r.data.*); } /// True iff any element differs after scale resolution. - pub inline fn neAll(self: *const Self, other: anytype) bool { - return !self.eqAll(other); + pub inline fn neAll(self: *const Self, alloc: Allocator, other: anytype) !bool { + return !(try self.eqAll(alloc, other)); } }; } @@ -272,21 +278,21 @@ test "TensorAlloc | Scalar comparisons (eq, ne, gt, gte, lt, lte)" { const m1000 = try Meter.splat(alloc, 1000); const km1 = try KiloMeter.splat(alloc, 1); - // const km2 = try KiloMeter.splat(alloc, 2); + const km2 = try KiloMeter.splat(alloc, 2); try std.testing.expect(try m1000.eq(alloc, km1)); try std.testing.expect(try km1.eq(alloc, m1000)); - // try std.testing.expect(km2.ne(m1000)); - // - // try std.testing.expect(km2.gt(m1000)); - // try std.testing.expect(km2.gt(km1)); - // try std.testing.expect(km1.gte(m1000)); - // try std.testing.expect(km2.gte(m1000)); - // - // try std.testing.expect(m1000.lt(km2)); - // try std.testing.expect(km1.lt(km2)); - // try std.testing.expect(km1.lte(m1000)); - // try std.testing.expect(m1000.lte(km2)); + try std.testing.expect(try km2.ne(alloc, m1000)); + + try std.testing.expect(try km2.gt(alloc, m1000)); + try std.testing.expect(try km2.gt(alloc, km1)); + try std.testing.expect(try km1.gte(alloc, m1000)); + try std.testing.expect(try km2.gte(alloc, m1000)); + + try std.testing.expect(try m1000.lt(alloc, km2)); + try std.testing.expect(try km1.lt(alloc, km2)); + try std.testing.expect(try km1.lte(alloc, m1000)); + try std.testing.expect(try m1000.lte(alloc, km2)); } // test "TensorAlloc | Scalar Add" {