From aaf912f88aad469fde0063997dc4bfa661f725a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dcaro?= Date: Tue, 17 Dec 2024 11:14:53 +0100 Subject: [PATCH 01/11] Fix deprecation warning force_all_finite -> ensure_all_finite --- cebra/integrations/sklearn/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index 455213a3..dec7db3d 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -81,7 +81,7 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray: dtype=("float16", "float32", "float64"), order=None, copy=False, - force_all_finite=True, + ensure_all_finite=True, ensure_2d=True, allow_nd=False, ensure_min_samples=min_samples, @@ -112,7 +112,7 @@ def check_label_array(y: npt.NDArray, *, min_samples: int): dtype="numeric", order=None, copy=False, - force_all_finite=True, + ensure_all_finite=True, ensure_2d=False, allow_nd=False, ensure_min_samples=min_samples, From 2ca1163bb10c76f74d062e5f81b089cfc1a27371 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dcaro?= Date: Tue, 17 Dec 2024 12:02:58 +0100 Subject: [PATCH 02/11] Added version compatibility for sklearn 1.8+ --- cebra/integrations/sklearn/utils.py | 83 ++++++++++++++++++++--------- 1 file changed, 58 insertions(+), 25 deletions(-) diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index dec7db3d..f7b03b81 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -27,6 +27,9 @@ import cebra.helper +from packaging import version +from sklearn import __version__ as sklearn_version + def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple: """Handle deprecated arguments of a function until they are replaced. @@ -74,19 +77,35 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray: Returns: The converted and validated array. """ - return sklearn_utils_validation.check_array( - X, - accept_sparse=False, - accept_large_sparse=False, - dtype=("float16", "float32", "float64"), - order=None, - copy=False, - ensure_all_finite=True, - ensure_2d=True, - allow_nd=False, - ensure_min_samples=min_samples, - ensure_min_features=1, - ) + + if sklearn_version < version.parse("1.8"): + return sklearn_utils_validation.check_array( + X, + accept_sparse=False, + accept_large_sparse=False, + dtype=("float16", "float32", "float64"), + order=None, + copy=False, + force_all_finite=True, + ensure_2d=True, + allow_nd=False, + ensure_min_samples=min_samples, + ensure_min_features=1, + ) + else: + return sklearn_utils_validation.check_array( + X, + accept_sparse=False, + accept_large_sparse=False, + dtype=("float16", "float32", "float64"), + order=None, + copy=False, + ensure_all_finite=True, + ensure_2d=True, + allow_nd=False, + ensure_min_samples=min_samples, + ensure_min_features=1, + ) def check_label_array(y: npt.NDArray, *, min_samples: int): @@ -105,18 +124,32 @@ def check_label_array(y: npt.NDArray, *, min_samples: int): Returns: The converted and validated labels. """ - return sklearn_utils_validation.check_array( - y, - accept_sparse=False, - accept_large_sparse=False, - dtype="numeric", - order=None, - copy=False, - ensure_all_finite=True, - ensure_2d=False, - allow_nd=False, - ensure_min_samples=min_samples, - ) + if sklearn_version < version.parse("1.8"): + return sklearn_utils_validation.check_array( + y, + accept_sparse=False, + accept_large_sparse=False, + dtype="numeric", + order=None, + copy=False, + force_all_finite=True, + ensure_2d=False, + allow_nd=False, + ensure_min_samples=min_samples, + ) + else: + return sklearn_utils_validation.check_array( + y, + accept_sparse=False, + accept_large_sparse=False, + dtype="numeric", + order=None, + copy=False, + ensure_all_finite=True, + ensure_2d=False, + allow_nd=False, + ensure_min_samples=min_samples, + ) def check_device(device: str) -> str: From 3e4ee86835adaa005ccaf4288f8e829fc5b0006f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dcaro?= Date: Tue, 17 Dec 2024 12:06:35 +0100 Subject: [PATCH 03/11] checkpoint --- cebra/integrations/sklearn/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index f7b03b81..c1671b9e 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -29,6 +29,7 @@ from packaging import version from sklearn import __version__ as sklearn_version +sklearn_version = version.parse(sklearn_version) def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple: From 3ba6bc6ac617be0a7d0846c6ad42154effc0c983 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dcaro?= Date: Wed, 18 Dec 2024 12:04:49 +0100 Subject: [PATCH 04/11] Update cebra/integrations/sklearn/utils.py Co-authored-by: Steffen Schneider --- cebra/integrations/sklearn/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index c1671b9e..1f6c621f 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -28,9 +28,15 @@ import cebra.helper from packaging import version -from sklearn import __version__ as sklearn_version -sklearn_version = version.parse(sklearn_version) +import sklearn +def _check_array_ensure_all_finite(array, **kwargs): + if version.parse(sklearn.__version__) < version.parse("1.8"): + key = "force_all_finite" + else: + key = "ensure_all_finite" + kwargs[key] = True + return sklearn_utils_validation.check_array(array, **kwargs) def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple: """Handle deprecated arguments of a function until they are replaced. From 128257bf203adabfffc8f3626bf0a921b46fc0a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dcaro?= Date: Wed, 18 Dec 2024 12:05:13 +0100 Subject: [PATCH 05/11] Update cebra/integrations/sklearn/utils.py Co-authored-by: Steffen Schneider --- cebra/integrations/sklearn/utils.py | 41 +++++++++-------------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index 1f6c621f..4bee65fb 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -84,35 +84,18 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray: Returns: The converted and validated array. """ - - if sklearn_version < version.parse("1.8"): - return sklearn_utils_validation.check_array( - X, - accept_sparse=False, - accept_large_sparse=False, - dtype=("float16", "float32", "float64"), - order=None, - copy=False, - force_all_finite=True, - ensure_2d=True, - allow_nd=False, - ensure_min_samples=min_samples, - ensure_min_features=1, - ) - else: - return sklearn_utils_validation.check_array( - X, - accept_sparse=False, - accept_large_sparse=False, - dtype=("float16", "float32", "float64"), - order=None, - copy=False, - ensure_all_finite=True, - ensure_2d=True, - allow_nd=False, - ensure_min_samples=min_samples, - ensure_min_features=1, - ) + return _check_array_ensure_all_finite( + X, + accept_sparse=False, + accept_large_sparse=False, + dtype=("float16", "float32", "float64"), + order=None, + copy=False, + ensure_2d=True, + allow_nd=False, + ensure_min_samples=min_samples, + ensure_min_features=1, + ) def check_label_array(y: npt.NDArray, *, min_samples: int): From 5770dfa13060014a396b48098eae870af27e5c6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dcaro?= Date: Wed, 18 Dec 2024 12:05:23 +0100 Subject: [PATCH 06/11] Update cebra/integrations/sklearn/utils.py Co-authored-by: Steffen Schneider --- cebra/integrations/sklearn/utils.py | 37 +++++++++-------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index 4bee65fb..41803763 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -114,32 +114,17 @@ def check_label_array(y: npt.NDArray, *, min_samples: int): Returns: The converted and validated labels. """ - if sklearn_version < version.parse("1.8"): - return sklearn_utils_validation.check_array( - y, - accept_sparse=False, - accept_large_sparse=False, - dtype="numeric", - order=None, - copy=False, - force_all_finite=True, - ensure_2d=False, - allow_nd=False, - ensure_min_samples=min_samples, - ) - else: - return sklearn_utils_validation.check_array( - y, - accept_sparse=False, - accept_large_sparse=False, - dtype="numeric", - order=None, - copy=False, - ensure_all_finite=True, - ensure_2d=False, - allow_nd=False, - ensure_min_samples=min_samples, - ) + return _check_array_ensure_all_finite( + y, + accept_sparse=False, + accept_large_sparse=False, + dtype="numeric", + order=None, + copy=False, + ensure_2d=False, + allow_nd=False, + ensure_min_samples=min_samples, + ) def check_device(device: str) -> str: From 3697d1aa39d9c6366e8b64686b4c00d711a4d3df Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Tue, 21 Jan 2025 22:38:35 +0100 Subject: [PATCH 07/11] Source formatting --- cebra/integrations/sklearn/utils.py | 31 ++++++++++++++++------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index 41803763..bb7d38d6 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -22,22 +22,25 @@ import warnings import numpy.typing as npt +import packaging +import sklearn import sklearn.utils.validation as sklearn_utils_validation import torch import cebra.helper -from packaging import version -import sklearn def _check_array_ensure_all_finite(array, **kwargs): - if version.parse(sklearn.__version__) < version.parse("1.8"): + # NOTE(stes): See discussion in https://github.com/AdaptiveMotorControlLab/CEBRA/pull/206 + if packaging.version.parse( + sklearn.__version__) < packaging.version.parse("1.8"): key = "force_all_finite" else: key = "ensure_all_finite" kwargs[key] = True return sklearn_utils_validation.check_array(array, **kwargs) + def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple: """Handle deprecated arguments of a function until they are replaced. @@ -85,17 +88,17 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray: The converted and validated array. """ return _check_array_ensure_all_finite( - X, - accept_sparse=False, - accept_large_sparse=False, - dtype=("float16", "float32", "float64"), - order=None, - copy=False, - ensure_2d=True, - allow_nd=False, - ensure_min_samples=min_samples, - ensure_min_features=1, - ) + X, + accept_sparse=False, + accept_large_sparse=False, + dtype=("float16", "float32", "float64"), + order=None, + copy=False, + ensure_2d=True, + allow_nd=False, + ensure_min_samples=min_samples, + ensure_min_features=1, + ) def check_label_array(y: npt.NDArray, *, min_samples: int): From fb0593df272e51af3096ba1c4d0572558374a0e8 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Tue, 21 Jan 2025 22:41:27 +0100 Subject: [PATCH 08/11] improve backwards compatibility implementation --- cebra/integrations/sklearn/utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index bb7d38d6..d787cdbe 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -30,14 +30,13 @@ import cebra.helper -def _check_array_ensure_all_finite(array, **kwargs): +def _sklearn_check_array(array, **kwargs): # NOTE(stes): See discussion in https://github.com/AdaptiveMotorControlLab/CEBRA/pull/206 if packaging.version.parse( sklearn.__version__) < packaging.version.parse("1.8"): - key = "force_all_finite" - else: - key = "ensure_all_finite" - kwargs[key] = True + if "ensure_all_finite" in kwargs: + kwargs["force_all_finite"] = kwargs["ensure_all_finite"] + del kwargs["ensure_all_finite"] return sklearn_utils_validation.check_array(array, **kwargs) @@ -87,7 +86,7 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray: Returns: The converted and validated array. """ - return _check_array_ensure_all_finite( + return _sklearn_check_array( X, accept_sparse=False, accept_large_sparse=False, @@ -95,6 +94,7 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray: order=None, copy=False, ensure_2d=True, + ensure_all_finite=True, allow_nd=False, ensure_min_samples=min_samples, ensure_min_features=1, @@ -117,7 +117,7 @@ def check_label_array(y: npt.NDArray, *, min_samples: int): Returns: The converted and validated labels. """ - return _check_array_ensure_all_finite( + return _sklearn_check_array( y, accept_sparse=False, accept_large_sparse=False, @@ -125,6 +125,7 @@ def check_label_array(y: npt.NDArray, *, min_samples: int): order=None, copy=False, ensure_2d=False, + ensure_all_finite=True, allow_nd=False, ensure_min_samples=min_samples, ) From 63ae99a7725c1aa1c2ebf084147c7c4fa5c89138 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Tue, 21 Jan 2025 22:49:28 +0100 Subject: [PATCH 09/11] update workflows to actions/setup-python v4 --- .github/workflows/doc-coverage.yml | 2 +- .github/workflows/docs.yml | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/doc-coverage.yml b/.github/workflows/doc-coverage.yml index 268cbee0..8c4a7715 100644 --- a/.github/workflows/doc-coverage.yml +++ b/.github/workflows/doc-coverage.yml @@ -38,7 +38,7 @@ jobs: restore-keys: | ${{ runner.os }}-pip - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install package diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 83c9d829..686b274a 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -52,7 +52,7 @@ jobs: ref: main - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -60,12 +60,12 @@ jobs: run: | python -m pip install --upgrade pip setuptools wheel # NOTE(stes) Pandoc version must be at least (2.14.2) but less than (4.0.0). - # as of 29/10/23. Ubuntu 22.04 which is used for ubuntu-latest only has an + # as of 29/10/23. Ubuntu 22.04 which is used for ubuntu-latest only has an # old pandoc version (2.9.). We will hence install the latest version manually. # previou: sudo apt-get install -y pandoc - wget https://github.com/jgm/pandoc/releases/download/3.1.9/pandoc-3.1.9-1-amd64.deb - sudo dpkg -i pandoc-3.1.9-1-amd64.deb - rm pandoc-3.1.9-1-amd64.deb + wget https://github.com/jgm/pandoc/releases/download/3.1.9/pandoc-3.1.9-1-amd64.deb + sudo dpkg -i pandoc-3.1.9-1-amd64.deb + rm pandoc-3.1.9-1-amd64.deb pip install torch --extra-index-url https://download.pytorch.org/whl/cpu pip install '.[docs]' From 5f3108cd69c4149d5e74ca79165907f69fde7633 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Tue, 21 Jan 2025 22:53:19 +0100 Subject: [PATCH 10/11] upgrade to v5 and python 3.9 --- .github/workflows/build.yml | 6 +++--- .github/workflows/doc-coverage.yml | 4 ++-- .github/workflows/docs.yml | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ef9e1777..52cb87f5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,7 +25,7 @@ jobs: torch-version: 2.4.0 python-version: "3.10" sklearn-version: "latest" - - os: ubuntu-latest + - os: ubuntu-latest torch-version: 2.4.0 python-version: "3.10" sklearn-version: "legacy" @@ -44,7 +44,7 @@ jobs: uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -54,7 +54,7 @@ jobs: python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu pip install '.[dev,datasets,integrations]' - - name: Check sklearn legacy version + - name: Check sklearn legacy version if: matrix.sklearn-version == 'legacy' run: | pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]' diff --git a/.github/workflows/doc-coverage.yml b/.github/workflows/doc-coverage.yml index 8c4a7715..4c5e3f45 100644 --- a/.github/workflows/doc-coverage.yml +++ b/.github/workflows/doc-coverage.yml @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.8'] + python-version: ['3.9'] steps: # NOTE(stes) currently not used, we check @@ -38,7 +38,7 @@ jobs: restore-keys: | ${{ runner.os }}-pip - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install package diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 686b274a..79dfdfa7 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -52,7 +52,7 @@ jobs: ref: main - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} From fb431e118cdbb094ed0f705f7df79ad7543dc039 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Wed, 22 Jan 2025 00:54:40 +0100 Subject: [PATCH 11/11] Update cebra/integrations/sklearn/utils.py --- cebra/integrations/sklearn/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index d787cdbe..d9bb3083 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -32,8 +32,10 @@ def _sklearn_check_array(array, **kwargs): # NOTE(stes): See discussion in https://github.com/AdaptiveMotorControlLab/CEBRA/pull/206 + # https://scikit-learn.org/1.6/modules/generated/sklearn.utils.check_array.html + # force_all_finite was renamed to ensure_all_finite and will be removed in 1.8. if packaging.version.parse( - sklearn.__version__) < packaging.version.parse("1.8"): + sklearn.__version__) < packaging.version.parse("1.6"): if "ensure_all_finite" in kwargs: kwargs["force_all_finite"] = kwargs["ensure_all_finite"] del kwargs["ensure_all_finite"]