Skip to content

Commit

Permalink
Fix arange with inf step (ml-explore#686)
Browse files Browse the repository at this point in the history
* Fix case for step=inf in arange and add inf check for start/stop

* Add test cases for arange

* Update ops.cpp to include climits header

* Fix arange

* Fix formatting

* Refactor

* Add missing include
  • Loading branch information
noahfarr authored Feb 23, 2024
1 parent 126c986 commit d729a19
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
19 changes: 17 additions & 2 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright © 2023-2024 Apple Inc.

#include <algorithm>
#include <climits>
#include <cmath>
#include <numeric>
#include <set>
Expand Down Expand Up @@ -73,10 +74,24 @@ array arange(
if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) {
throw std::invalid_argument("[arange] Cannot compute length.");
}
double real_size = std::ceil((stop - start) / step);
if (std::isnan(real_size)) {

if (std::isinf(start) || std::isinf(stop)) {
throw std::invalid_argument("[arange] Cannot compute length.");
}

// Check if start and stop specify a valid range because if not, we have to
// return an empty array
if (std::isinf(step) &&
(step > 0 && start < stop || step < 0 && start > stop)) {
return array({start}, dtype);
}

double real_size = std::ceil((stop - start) / step);

if (real_size > INT_MAX) {
throw std::invalid_argument("[arange] Maximum size exceeded.");
}

int size = std::max(static_cast<int>(real_size), 0);
return array(
{size},
Expand Down
26 changes: 26 additions & 0 deletions python/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,11 @@ def test_arange_overload_dispatch(self):
a = mx.arange(0, float("inf"), float("inf"))
with self.assertRaises(ValueError):
a = mx.arange(float("inf"), 1, float("inf"))
with self.assertRaises(ValueError):
a = mx.arange(float("inf"), 1, 5)
with self.assertRaises(ValueError):
INT_MAX = 2147483647
a = mx.arange(0, INT_MAX + 1, 1)

a = mx.arange(5)
expected = [0, 1, 2, 3, 4]
Expand Down Expand Up @@ -1132,6 +1137,27 @@ def test_arange_corner_cases_cast(self):
self.assertListEqual(a.tolist(), expected)
self.assertEqual(a.dtype, mx.int32)

a = mx.arange(0, 10, 100)
expected = [0]
self.assertListEqual(a.tolist(), expected)
self.assertEqual(a.dtype, mx.int32)

a = mx.arange(10, 0, 1)
expected = []
self.assertListEqual(a.tolist(), expected)

a = mx.arange(10, 0, float("inf"))
expected = []
self.assertListEqual(a.tolist(), expected)

a = mx.arange(0, 10, float("inf"))
expected = [0]
self.assertListEqual(a.tolist(), expected)

a = mx.arange(0, -10, float("-inf"))
expected = [0]
self.assertListEqual(a.tolist(), expected)

def test_unary_ops(self):
def test_ops(npop, mlxop, x, y, atol):
r_np = npop(x)
Expand Down

0 comments on commit d729a19

Please sign in to comment.