Skip to content

Commit

Permalink
heat.eq, heat.ne now allow non-array operands (#1773) (#1791)
Browse files Browse the repository at this point in the history
* changed eq and ne so that the input of wrong Types does not cause an error

* Changed eq and ne to include try except for wrong Types

* Changed tests to assert True/False instead of Errors

* fixed spelling of erroneous_type

---------

Co-authored-by: Claudia Comito <[email protected]>
(cherry picked from commit c282cb1)

Co-authored-by: Marc-Jindra <[email protected]>
  • Loading branch information
github-actions[bot] and Marc-Jindra authored Feb 12, 2025
1 parent 5ff62fb commit 7e6e3bc
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 45 deletions.
68 changes: 40 additions & 28 deletions heat/core/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@
]


def eq(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
def eq(x, y) -> DNDarray:
"""
Returns a :class:`~heat.core.dndarray.DNDarray` containing the results of element-wise comparision.
Takes the first and second operand (scalar or :class:`~heat.core.dndarray.DNDarray`) whose elements are to be
compared as argument.
Returns False if the operands are not scalars or :class:`~heat.core.dndarray.DNDarray`
Parameters
----------
Expand All @@ -57,21 +58,26 @@ def eq(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr
>>> ht.eq(x, y)
DNDarray([[False, True],
[False, False]], dtype=ht.bool, device=cpu:0, split=None)
>>> ht.eq(x, slice(None))
False
"""
res = _operations.__binary_op(torch.eq, x, y)

if res.dtype != types.bool:
res = dndarray.DNDarray(
res.larray.type(torch.bool),
res.gshape,
types.bool,
res.split,
res.device,
res.comm,
res.balanced,
)
try:
res = _operations.__binary_op(torch.eq, x, y)

if res.dtype != types.bool:
res = dndarray.DNDarray(
res.larray.type(torch.bool),
res.gshape,
types.bool,
res.split,
res.device,
res.comm,
res.balanced,
)

return res
return res
except (TypeError, ValueError):
return False


DNDarray.__eq__ = lambda self, other: eq(self, other)
Expand Down Expand Up @@ -372,11 +378,12 @@ def lt(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr
less.__doc__ = lt.__doc__


def ne(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
def ne(x, y) -> DNDarray:
"""
Returns a :class:`~heat.core.dndarray.DNDarray` containing the results of element-wise rich comparison of non-equality between values from two operands, commutative.
Takes the first and second operand (scalar or :class:`~heat.core.dndarray.DNDarray`) whose elements are to be
compared as argument.
Returns True if the operands are not scalars or :class:`~heat.core.dndarray.DNDarray`
Parameters
----------
Expand All @@ -396,21 +403,26 @@ def ne(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr
>>> ht.ne(x, y)
DNDarray([[ True, False],
[ True, True]], dtype=ht.bool, device=cpu:0, split=None)
>>> ht.ne(x, slice(None))
True
"""
res = _operations.__binary_op(torch.ne, x, y)

if res.dtype != types.bool:
res = dndarray.DNDarray(
res.larray.type(torch.bool),
res.gshape,
types.bool,
res.split,
res.device,
res.comm,
res.balanced,
)
try:
res = _operations.__binary_op(torch.ne, x, y)

if res.dtype != types.bool:
res = dndarray.DNDarray(
res.larray.type(torch.bool),
res.gshape,
types.bool,
res.split,
res.device,
res.comm,
res.balanced,
)

return res
return res
except (TypeError, ValueError):
return True


DNDarray.__ne__ = lambda self, other: ne(self, other)
Expand Down
28 changes: 11 additions & 17 deletions heat/core/tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def setUpClass(cls):
cls.a_split_tensor = cls.another_tensor.copy().resplit_(0)
cls.split_ones_tensor = ht.ones((2, 2), split=1)

cls.errorneous_type = (2, 2)
cls.erroneous_type = (2, 2)

def test_eq(self):
result = ht.array([[False, True], [False, False]])
Expand All @@ -32,12 +32,9 @@ def test_eq(self):

self.assertEqual(ht.eq(self.a_split_tensor, self.a_tensor).dtype, ht.bool)

with self.assertRaises(ValueError):
ht.eq(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.eq(self.a_tensor, self.errorneous_type)
with self.assertRaises(TypeError):
ht.eq("self.a_tensor", "s")
self.assertFalse(ht.eq(self.a_tensor, self.another_vector))
self.assertFalse(ht.eq(self.a_tensor, self.erroneous_type))
self.assertFalse(ht.eq("self.a_tensor", "s"))

def test_equal(self):
self.assertTrue(ht.equal(self.a_tensor, self.a_tensor))
Expand Down Expand Up @@ -78,7 +75,7 @@ def test_ge(self):
with self.assertRaises(ValueError):
ht.ge(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.ge(self.a_tensor, self.errorneous_type)
ht.ge(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.ge("self.a_tensor", "s")

Expand All @@ -99,7 +96,7 @@ def test_gt(self):
with self.assertRaises(ValueError):
ht.gt(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.gt(self.a_tensor, self.errorneous_type)
ht.gt(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.gt("self.a_tensor", "s")

Expand All @@ -120,7 +117,7 @@ def test_le(self):
with self.assertRaises(ValueError):
ht.le(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.le(self.a_tensor, self.errorneous_type)
ht.le(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.le("self.a_tensor", "s")

Expand All @@ -141,7 +138,7 @@ def test_lt(self):
with self.assertRaises(ValueError):
ht.lt(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.lt(self.a_tensor, self.errorneous_type)
ht.lt(self.a_tensor, self.erroneous_type)
with self.assertRaises(TypeError):
ht.lt("self.a_tensor", "s")

Expand All @@ -159,9 +156,6 @@ def test_ne(self):

self.assertEqual(ht.ne(self.a_split_tensor, self.a_tensor).dtype, ht.bool)

with self.assertRaises(ValueError):
ht.ne(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
ht.ne(self.a_tensor, self.errorneous_type)
with self.assertRaises(TypeError):
ht.ne("self.a_tensor", "s")
self.assertTrue(ht.ne(self.a_tensor, self.another_vector))
self.assertTrue(ht.ne(self.a_tensor, self.erroneous_type))
self.assertTrue(ht.ne("self.a_tensor", "s"))

0 comments on commit 7e6e3bc

Please sign in to comment.