From 5770dfa13060014a396b48098eae870af27e5c6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dcaro?= Date: Wed, 18 Dec 2024 12:05:23 +0100 Subject: [PATCH] Update cebra/integrations/sklearn/utils.py Co-authored-by: Steffen Schneider --- cebra/integrations/sklearn/utils.py | 37 +++++++++-------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index 4bee65fb..41803763 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -114,32 +114,17 @@ def check_label_array(y: npt.NDArray, *, min_samples: int): Returns: The converted and validated labels. """ - if sklearn_version < version.parse("1.8"): - return sklearn_utils_validation.check_array( - y, - accept_sparse=False, - accept_large_sparse=False, - dtype="numeric", - order=None, - copy=False, - force_all_finite=True, - ensure_2d=False, - allow_nd=False, - ensure_min_samples=min_samples, - ) - else: - return sklearn_utils_validation.check_array( - y, - accept_sparse=False, - accept_large_sparse=False, - dtype="numeric", - order=None, - copy=False, - ensure_all_finite=True, - ensure_2d=False, - allow_nd=False, - ensure_min_samples=min_samples, - ) + return _check_array_ensure_all_finite( + y, + accept_sparse=False, + accept_large_sparse=False, + dtype="numeric", + order=None, + copy=False, + ensure_2d=False, + allow_nd=False, + ensure_min_samples=min_samples, + ) def check_device(device: str) -> str: