Working comparaison for TensorAlloc
This commit is contained in:
parent
ba671ee486
commit
91c5c41fc5
@ -74,7 +74,7 @@ pub fn Tensor(
|
||||
comptime Dest: type,
|
||||
alloc: Allocator,
|
||||
) !Dest {
|
||||
if (comptime Self == Dest) return self.*;
|
||||
if (comptime Self == Dest) return self.copy(alloc);
|
||||
|
||||
// Run validation checks FIRST before dealing with types
|
||||
if (comptime !dims.eql(Dest.dims))
|
||||
@ -191,52 +191,58 @@ pub fn Tensor(
|
||||
return cmpResult(p.l.data.* == p.r.data.*);
|
||||
}
|
||||
|
||||
pub inline fn ne(self: *const Self, rhs: anytype) CmpResult {
|
||||
pub inline fn ne(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||
@compileError("Dimension mismatch in ne.");
|
||||
const p = resolveScalePair(self, rhs);
|
||||
return cmpResult(p.l != p.r);
|
||||
const p = try resolveScalePair(self, alloc, rhs);
|
||||
defer p.deinit(alloc);
|
||||
return cmpResult(p.l.data.* != p.r.data.*);
|
||||
}
|
||||
|
||||
pub inline fn gt(self: *const Self, rhs: anytype) CmpResult {
|
||||
pub inline fn gt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||
@compileError("Dimension mismatch in gt.");
|
||||
const p = resolveScalePair(self, rhs);
|
||||
return cmpResult(p.l > p.r);
|
||||
const p = try resolveScalePair(self, alloc, rhs);
|
||||
defer p.deinit(alloc);
|
||||
return cmpResult(p.l.data.* > p.r.data.*);
|
||||
}
|
||||
|
||||
pub inline fn gte(self: *const Self, rhs: anytype) CmpResult {
|
||||
pub inline fn gte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||
@compileError("Dimension mismatch in gte.");
|
||||
const p = resolveScalePair(self, rhs);
|
||||
return cmpResult(p.l >= p.r);
|
||||
const p = try resolveScalePair(self, alloc, rhs);
|
||||
defer p.deinit(alloc);
|
||||
return cmpResult(p.l.data.* >= p.r.data.*);
|
||||
}
|
||||
|
||||
pub inline fn lt(self: *const Self, rhs: anytype) CmpResult {
|
||||
pub inline fn lt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||
@compileError("Dimension mismatch in lt.");
|
||||
const p = resolveScalePair(self, rhs);
|
||||
return cmpResult(p.l < p.r);
|
||||
const p = try resolveScalePair(self, alloc, rhs);
|
||||
defer p.deinit(alloc);
|
||||
return cmpResult(p.l.data.* < p.r.data.*);
|
||||
}
|
||||
|
||||
pub inline fn lte(self: *const Self, rhs: anytype) CmpResult {
|
||||
pub inline fn lte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||
@compileError("Dimension mismatch in lte.");
|
||||
const p = resolveScalePair(self, rhs);
|
||||
return cmpResult(p.l <= p.r);
|
||||
const p = try resolveScalePair(self, alloc, rhs);
|
||||
defer p.deinit(alloc);
|
||||
return cmpResult(p.l.data.* <= p.r.data.*);
|
||||
}
|
||||
|
||||
/// True iff every element is equal after scale resolution.
|
||||
pub inline fn eqAll(self: *const Self, other: anytype) bool {
|
||||
pub inline fn eqAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
|
||||
if (comptime !dims.eql(@TypeOf(other).dims))
|
||||
@compileError("Dimension mismatch in eqAll.");
|
||||
const p = resolveScalePair(self, other);
|
||||
return @reduce(.And, p.l == p.r);
|
||||
const p = try resolveScalePair(self, alloc, other);
|
||||
defer p.deinit(alloc);
|
||||
return @reduce(.And, p.l.data.* == p.r.data.*);
|
||||
}
|
||||
|
||||
/// True iff any element differs after scale resolution.
|
||||
pub inline fn neAll(self: *const Self, other: anytype) bool {
|
||||
return !self.eqAll(other);
|
||||
pub inline fn neAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
|
||||
return !(try self.eqAll(alloc, other));
|
||||
}
|
||||
};
|
||||
}
|
||||
@ -272,21 +278,21 @@ test "TensorAlloc | Scalar comparisons (eq, ne, gt, gte, lt, lte)" {
|
||||
|
||||
const m1000 = try Meter.splat(alloc, 1000);
|
||||
const km1 = try KiloMeter.splat(alloc, 1);
|
||||
// const km2 = try KiloMeter.splat(alloc, 2);
|
||||
const km2 = try KiloMeter.splat(alloc, 2);
|
||||
|
||||
try std.testing.expect(try m1000.eq(alloc, km1));
|
||||
try std.testing.expect(try km1.eq(alloc, m1000));
|
||||
// try std.testing.expect(km2.ne(m1000));
|
||||
//
|
||||
// try std.testing.expect(km2.gt(m1000));
|
||||
// try std.testing.expect(km2.gt(km1));
|
||||
// try std.testing.expect(km1.gte(m1000));
|
||||
// try std.testing.expect(km2.gte(m1000));
|
||||
//
|
||||
// try std.testing.expect(m1000.lt(km2));
|
||||
// try std.testing.expect(km1.lt(km2));
|
||||
// try std.testing.expect(km1.lte(m1000));
|
||||
// try std.testing.expect(m1000.lte(km2));
|
||||
try std.testing.expect(try km2.ne(alloc, m1000));
|
||||
|
||||
try std.testing.expect(try km2.gt(alloc, m1000));
|
||||
try std.testing.expect(try km2.gt(alloc, km1));
|
||||
try std.testing.expect(try km1.gte(alloc, m1000));
|
||||
try std.testing.expect(try km2.gte(alloc, m1000));
|
||||
|
||||
try std.testing.expect(try m1000.lt(alloc, km2));
|
||||
try std.testing.expect(try km1.lt(alloc, km2));
|
||||
try std.testing.expect(try km1.lte(alloc, m1000));
|
||||
try std.testing.expect(try m1000.lte(alloc, km2));
|
||||
}
|
||||
|
||||
// test "TensorAlloc | Scalar Add" {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user