Skip to content

Commit

Permalink
Update cebra/integrations/sklearn/utils.py
Browse files Browse the repository at this point in the history
Co-authored-by: Steffen Schneider <[email protected]>
  • Loading branch information
icarosadero and stes authored Dec 18, 2024
1 parent 128257b commit 5770dfa
Showing 1 changed file with 11 additions and 26 deletions.
37 changes: 11 additions & 26 deletions cebra/integrations/sklearn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,32 +114,17 @@ def check_label_array(y: npt.NDArray, *, min_samples: int):
Returns:
The converted and validated labels.
"""
if sklearn_version < version.parse("1.8"):
return sklearn_utils_validation.check_array(
y,
accept_sparse=False,
accept_large_sparse=False,
dtype="numeric",
order=None,
copy=False,
force_all_finite=True,
ensure_2d=False,
allow_nd=False,
ensure_min_samples=min_samples,
)
else:
return sklearn_utils_validation.check_array(
y,
accept_sparse=False,
accept_large_sparse=False,
dtype="numeric",
order=None,
copy=False,
ensure_all_finite=True,
ensure_2d=False,
allow_nd=False,
ensure_min_samples=min_samples,
)
return _check_array_ensure_all_finite(
y,
accept_sparse=False,
accept_large_sparse=False,
dtype="numeric",
order=None,
copy=False,
ensure_2d=False,
allow_nd=False,
ensure_min_samples=min_samples,
)


def check_device(device: str) -> str:
Expand Down

0 comments on commit 5770dfa

Please sign in to comment.