From d729a1991bd7c35e1e44af0c791e4b0a222b63af Mon Sep 17 00:00:00 2001 From: Noah Farr <69793313+noahfarr@users.noreply.github.com> Date: Fri, 23 Feb 2024 15:18:15 +0100 Subject: [PATCH] Fix arange with inf step (#686) * Fix case for step=inf in arange and add inf check for start/stop * Add test cases for arange * Update ops.cpp to include climits header * Fix arange * Fix formatting * Refactor * Add missing include --- mlx/ops.cpp | 19 +++++++++++++++++-- python/tests/test_ops.py | 26 ++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index a1d4c5b4e8..2c3cf55abf 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1,6 +1,7 @@ // Copyright © 2023-2024 Apple Inc. #include +#include #include #include #include @@ -73,10 +74,24 @@ array arange( if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) { throw std::invalid_argument("[arange] Cannot compute length."); } - double real_size = std::ceil((stop - start) / step); - if (std::isnan(real_size)) { + + if (std::isinf(start) || std::isinf(stop)) { throw std::invalid_argument("[arange] Cannot compute length."); } + + // Check if start and stop specify a valid range because if not, we have to + // return an empty array + if (std::isinf(step) && + (step > 0 && start < stop || step < 0 && start > stop)) { + return array({start}, dtype); + } + + double real_size = std::ceil((stop - start) / step); + + if (real_size > INT_MAX) { + throw std::invalid_argument("[arange] Maximum size exceeded."); + } + int size = std::max(static_cast(real_size), 0); return array( {size}, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 3401338f8f..23a8c7bc11 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1047,6 +1047,11 @@ def test_arange_overload_dispatch(self): a = mx.arange(0, float("inf"), float("inf")) with self.assertRaises(ValueError): a = mx.arange(float("inf"), 1, float("inf")) + with self.assertRaises(ValueError): + a = mx.arange(float("inf"), 1, 5) + with self.assertRaises(ValueError): + INT_MAX = 2147483647 + a = mx.arange(0, INT_MAX + 1, 1) a = mx.arange(5) expected = [0, 1, 2, 3, 4] @@ -1132,6 +1137,27 @@ def test_arange_corner_cases_cast(self): self.assertListEqual(a.tolist(), expected) self.assertEqual(a.dtype, mx.int32) + a = mx.arange(0, 10, 100) + expected = [0] + self.assertListEqual(a.tolist(), expected) + self.assertEqual(a.dtype, mx.int32) + + a = mx.arange(10, 0, 1) + expected = [] + self.assertListEqual(a.tolist(), expected) + + a = mx.arange(10, 0, float("inf")) + expected = [] + self.assertListEqual(a.tolist(), expected) + + a = mx.arange(0, 10, float("inf")) + expected = [0] + self.assertListEqual(a.tolist(), expected) + + a = mx.arange(0, -10, float("-inf")) + expected = [0] + self.assertListEqual(a.tolist(), expected) + def test_unary_ops(self): def test_ops(npop, mlxop, x, y, atol): r_np = npop(x)