diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index 455213a3..d9bb3083 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -22,12 +22,26 @@ import warnings import numpy.typing as npt +import packaging +import sklearn import sklearn.utils.validation as sklearn_utils_validation import torch import cebra.helper +def _sklearn_check_array(array, **kwargs): + # NOTE(stes): See discussion in https://github.com/AdaptiveMotorControlLab/CEBRA/pull/206 + # https://scikit-learn.org/1.6/modules/generated/sklearn.utils.check_array.html + # force_all_finite was renamed to ensure_all_finite and will be removed in 1.8. + if packaging.version.parse( + sklearn.__version__) < packaging.version.parse("1.6"): + if "ensure_all_finite" in kwargs: + kwargs["force_all_finite"] = kwargs["ensure_all_finite"] + del kwargs["ensure_all_finite"] + return sklearn_utils_validation.check_array(array, **kwargs) + + def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple: """Handle deprecated arguments of a function until they are replaced. @@ -74,15 +88,15 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray: Returns: The converted and validated array. """ - return sklearn_utils_validation.check_array( + return _sklearn_check_array( X, accept_sparse=False, accept_large_sparse=False, dtype=("float16", "float32", "float64"), order=None, copy=False, - force_all_finite=True, ensure_2d=True, + ensure_all_finite=True, allow_nd=False, ensure_min_samples=min_samples, ensure_min_features=1, @@ -105,15 +119,15 @@ def check_label_array(y: npt.NDArray, *, min_samples: int): Returns: The converted and validated labels. """ - return sklearn_utils_validation.check_array( + return _sklearn_check_array( y, accept_sparse=False, accept_large_sparse=False, dtype="numeric", order=None, copy=False, - force_all_finite=True, ensure_2d=False, + ensure_all_finite=True, allow_nd=False, ensure_min_samples=min_samples, )