diff --git a/modin/numpy/__init__.py b/modin/numpy/__init__.py index 765fc615786..4cf4f23f984 100644 --- a/modin/numpy/__init__.py +++ b/modin/numpy/__init__.py @@ -11,27 +11,28 @@ # ANY KIND, either express or implied. See the License for the specific language # governing permissions and limitations under the License. +import numpy +from packaging import version + from . import linalg from .arr import array from .array_creation import ones_like, tri, zeros_like from .array_shaping import append, hstack, ravel, shape, split, transpose -from .constants import ( - NAN, - NINF, - NZERO, - PINF, - PZERO, - Inf, - Infinity, - NaN, - e, - euler_gamma, - inf, - infty, - nan, - newaxis, - pi, -) +from .constants import e, euler_gamma, inf, nan, newaxis, pi + +if version.parse(numpy.__version__) < version.parse("2.0.0b1"): + from .constants import ( + NAN, + NINF, + NZERO, + PINF, + PZERO, + Inf, + Infinity, + NaN, + infty, + ) + from .logic import ( all, any, @@ -151,18 +152,9 @@ def where(condition, x=None, y=None): "amin", "min", "where", - "Inf", - "Infinity", - "NAN", - "NINF", - "NZERO", - "NaN", - "PINF", - "PZERO", "e", "euler_gamma", "inf", - "infty", "nan", "newaxis", "pi", @@ -177,3 +169,15 @@ def where(condition, x=None, y=None): "append", "tri", ] +if version.parse(numpy.__version__) < version.parse("2.0.0b1"): + __all__ += [ + "Inf", + "Infinity", + "NAN", + "NINF", + "NZERO", + "NaN", + "PINF", + "PZERO", + "infty", + ] diff --git a/modin/numpy/constants.py b/modin/numpy/constants.py index 6070901e5f6..5f103f66d77 100644 --- a/modin/numpy/constants.py +++ b/modin/numpy/constants.py @@ -11,39 +11,31 @@ # ANY KIND, either express or implied. See the License for the specific language # governing permissions and limitations under the License. -# flake8: noqa -from numpy import ( - NAN, - NINF, - NZERO, - PINF, - PZERO, - Inf, - Infinity, - NaN, - e, - euler_gamma, - inf, - infty, - nan, - newaxis, - pi, -) +import numpy +from numpy import e, euler_gamma, inf, nan, newaxis, pi +from packaging import version + +if version.parse(numpy.__version__) < version.parse("2.0.0b1"): + from numpy import NAN, NINF, NZERO, PINF, PZERO, Inf, Infinity, NaN, infty __all__ = [ - "Inf", - "Infinity", - "NAN", - "NINF", - "NZERO", - "NaN", - "PINF", - "PZERO", "e", "euler_gamma", "inf", - "infty", "nan", "newaxis", "pi", ] + +if version.parse(numpy.__version__) < version.parse("2.0.0b1"): + __all__ += [ + "Inf", + "Infinity", + "NAN", + "NINF", + "NZERO", + "NaN", + "PINF", + "PZERO", + "infty", + ]