Working comparaison for TensorAlloc
This commit is contained in:
parent
ba671ee486
commit
91c5c41fc5
@ -74,7 +74,7 @@ pub fn Tensor(
|
|||||||
comptime Dest: type,
|
comptime Dest: type,
|
||||||
alloc: Allocator,
|
alloc: Allocator,
|
||||||
) !Dest {
|
) !Dest {
|
||||||
if (comptime Self == Dest) return self.*;
|
if (comptime Self == Dest) return self.copy(alloc);
|
||||||
|
|
||||||
// Run validation checks FIRST before dealing with types
|
// Run validation checks FIRST before dealing with types
|
||||||
if (comptime !dims.eql(Dest.dims))
|
if (comptime !dims.eql(Dest.dims))
|
||||||
@ -191,52 +191,58 @@ pub fn Tensor(
|
|||||||
return cmpResult(p.l.data.* == p.r.data.*);
|
return cmpResult(p.l.data.* == p.r.data.*);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn ne(self: *const Self, rhs: anytype) CmpResult {
|
pub inline fn ne(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||||
@compileError("Dimension mismatch in ne.");
|
@compileError("Dimension mismatch in ne.");
|
||||||
const p = resolveScalePair(self, rhs);
|
const p = try resolveScalePair(self, alloc, rhs);
|
||||||
return cmpResult(p.l != p.r);
|
defer p.deinit(alloc);
|
||||||
|
return cmpResult(p.l.data.* != p.r.data.*);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn gt(self: *const Self, rhs: anytype) CmpResult {
|
pub inline fn gt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||||
@compileError("Dimension mismatch in gt.");
|
@compileError("Dimension mismatch in gt.");
|
||||||
const p = resolveScalePair(self, rhs);
|
const p = try resolveScalePair(self, alloc, rhs);
|
||||||
return cmpResult(p.l > p.r);
|
defer p.deinit(alloc);
|
||||||
|
return cmpResult(p.l.data.* > p.r.data.*);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn gte(self: *const Self, rhs: anytype) CmpResult {
|
pub inline fn gte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||||
@compileError("Dimension mismatch in gte.");
|
@compileError("Dimension mismatch in gte.");
|
||||||
const p = resolveScalePair(self, rhs);
|
const p = try resolveScalePair(self, alloc, rhs);
|
||||||
return cmpResult(p.l >= p.r);
|
defer p.deinit(alloc);
|
||||||
|
return cmpResult(p.l.data.* >= p.r.data.*);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn lt(self: *const Self, rhs: anytype) CmpResult {
|
pub inline fn lt(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||||
@compileError("Dimension mismatch in lt.");
|
@compileError("Dimension mismatch in lt.");
|
||||||
const p = resolveScalePair(self, rhs);
|
const p = try resolveScalePair(self, alloc, rhs);
|
||||||
return cmpResult(p.l < p.r);
|
defer p.deinit(alloc);
|
||||||
|
return cmpResult(p.l.data.* < p.r.data.*);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn lte(self: *const Self, rhs: anytype) CmpResult {
|
pub inline fn lte(self: *const Self, alloc: Allocator, rhs: anytype) !CmpResult {
|
||||||
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
if (comptime !dims.eql(@TypeOf(rhs).dims))
|
||||||
@compileError("Dimension mismatch in lte.");
|
@compileError("Dimension mismatch in lte.");
|
||||||
const p = resolveScalePair(self, rhs);
|
const p = try resolveScalePair(self, alloc, rhs);
|
||||||
return cmpResult(p.l <= p.r);
|
defer p.deinit(alloc);
|
||||||
|
return cmpResult(p.l.data.* <= p.r.data.*);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// True iff every element is equal after scale resolution.
|
/// True iff every element is equal after scale resolution.
|
||||||
pub inline fn eqAll(self: *const Self, other: anytype) bool {
|
pub inline fn eqAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
|
||||||
if (comptime !dims.eql(@TypeOf(other).dims))
|
if (comptime !dims.eql(@TypeOf(other).dims))
|
||||||
@compileError("Dimension mismatch in eqAll.");
|
@compileError("Dimension mismatch in eqAll.");
|
||||||
const p = resolveScalePair(self, other);
|
const p = try resolveScalePair(self, alloc, other);
|
||||||
return @reduce(.And, p.l == p.r);
|
defer p.deinit(alloc);
|
||||||
|
return @reduce(.And, p.l.data.* == p.r.data.*);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// True iff any element differs after scale resolution.
|
/// True iff any element differs after scale resolution.
|
||||||
pub inline fn neAll(self: *const Self, other: anytype) bool {
|
pub inline fn neAll(self: *const Self, alloc: Allocator, other: anytype) !bool {
|
||||||
return !self.eqAll(other);
|
return !(try self.eqAll(alloc, other));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -272,21 +278,21 @@ test "TensorAlloc | Scalar comparisons (eq, ne, gt, gte, lt, lte)" {
|
|||||||
|
|
||||||
const m1000 = try Meter.splat(alloc, 1000);
|
const m1000 = try Meter.splat(alloc, 1000);
|
||||||
const km1 = try KiloMeter.splat(alloc, 1);
|
const km1 = try KiloMeter.splat(alloc, 1);
|
||||||
// const km2 = try KiloMeter.splat(alloc, 2);
|
const km2 = try KiloMeter.splat(alloc, 2);
|
||||||
|
|
||||||
try std.testing.expect(try m1000.eq(alloc, km1));
|
try std.testing.expect(try m1000.eq(alloc, km1));
|
||||||
try std.testing.expect(try km1.eq(alloc, m1000));
|
try std.testing.expect(try km1.eq(alloc, m1000));
|
||||||
// try std.testing.expect(km2.ne(m1000));
|
try std.testing.expect(try km2.ne(alloc, m1000));
|
||||||
//
|
|
||||||
// try std.testing.expect(km2.gt(m1000));
|
try std.testing.expect(try km2.gt(alloc, m1000));
|
||||||
// try std.testing.expect(km2.gt(km1));
|
try std.testing.expect(try km2.gt(alloc, km1));
|
||||||
// try std.testing.expect(km1.gte(m1000));
|
try std.testing.expect(try km1.gte(alloc, m1000));
|
||||||
// try std.testing.expect(km2.gte(m1000));
|
try std.testing.expect(try km2.gte(alloc, m1000));
|
||||||
//
|
|
||||||
// try std.testing.expect(m1000.lt(km2));
|
try std.testing.expect(try m1000.lt(alloc, km2));
|
||||||
// try std.testing.expect(km1.lt(km2));
|
try std.testing.expect(try km1.lt(alloc, km2));
|
||||||
// try std.testing.expect(km1.lte(m1000));
|
try std.testing.expect(try km1.lte(alloc, m1000));
|
||||||
// try std.testing.expect(m1000.lte(km2));
|
try std.testing.expect(try m1000.lte(alloc, km2));
|
||||||
}
|
}
|
||||||
|
|
||||||
// test "TensorAlloc | Scalar Add" {
|
// test "TensorAlloc | Scalar Add" {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user