From 77b5c856d63a35fe1b3cfacbe5d9d59c6561e6a6 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Mon, 16 Dec 2024 09:36:30 -0800 Subject: [PATCH 1/4] Add support for new __sklearn_tags__ --- cebra/integrations/sklearn/cebra.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 046d3344..5373a7c2 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -30,8 +30,10 @@ import pkg_resources import sklearn.utils.validation as sklearn_utils_validation import torch +import sklearn from sklearn.base import BaseEstimator from sklearn.base import TransformerMixin +from sklearn.utils.metaestimators import available_if from torch import nn import cebra.data @@ -41,6 +43,11 @@ import cebra.models import cebra.solver +def check_version(estimator): + # NOTE(stes): required as a check for the old way of specifying tags + # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165 + from packaging import version + return version.parse(sklearn.__version__) < version.parse("1.6.dev") def _init_loader( is_cont: bool, @@ -1294,6 +1301,15 @@ def fit_transform( callback_frequency=callback_frequency) return self.transform(X) + def __sklearn_tags__(self): + # NOTE(stes): from 1.6.dev, this is the new way to specify tags + # https://scikit-learn.org/dev/developers/develop.html + # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165 + tags = super().__sklearn_tags__() + tags.non_deterministic = True + return tags + + @available_if(check_version) def _more_tags(self): # NOTE(stes): This tag is needed as seeding is not fully implemented in the # current version of CEBRA. From 80f1ea9685ce46f4d457eeacf81f1ec9721f8a06 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Mon, 16 Dec 2024 09:46:59 -0800 Subject: [PATCH 2/4] fix inheritance order --- cebra/integrations/sklearn/cebra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 5373a7c2..9a74eeb6 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -371,7 +371,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": return cebra_ -class CEBRA(BaseEstimator, TransformerMixin): +class CEBRA(TransformerMixin, BaseEstimator): """CEBRA model defined as part of a ``scikit-learn``-like API. Attributes: From cb8b5da690e307e7a667b3d0b1ab592e9445c3c7 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Mon, 16 Dec 2024 10:04:34 -0800 Subject: [PATCH 3/4] Add more tests --- .github/workflows/build.yml | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a231258f..6ccfad5e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,10 +19,17 @@ jobs: # as well as selected previous versions on # https://pytorch.org/get-started/previous-versions/ torch-version: ["2.2.2", "2.4.0"] + sklearn-version: ["latest"] include: - os: windows-latest torch-version: 2.4.0 python-version: "3.10" + sklearn-version: "latest" + include: + - os: ubuntu-latest + torch-version: 2.4.0 + python-version: "3.10" + sklearn-version: "legacy" runs-on: ${{ matrix.os }} @@ -32,7 +39,7 @@ jobs: uses: actions/cache@v3 with: path: ~/.cache/pip - key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }} + key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}-sklearn_${{ matrix.sklearn-version }} - name: Checkout code uses: actions/checkout@v2 @@ -48,6 +55,11 @@ jobs: python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu pip install '.[dev,datasets,integrations]' + - name: Check sklearn legacy version + if: matrix.sklearn-version == 'legacy' + run: | + pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]' + - name: Run the formatter run: | make format From bd5407fa61dca842bda5a74e9766334504abfb08 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Mon, 16 Dec 2024 10:08:33 -0800 Subject: [PATCH 4/4] fix added test --- .github/workflows/build.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6ccfad5e..ef9e1777 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,7 +25,6 @@ jobs: torch-version: 2.4.0 python-version: "3.10" sklearn-version: "latest" - include: - os: ubuntu-latest torch-version: 2.4.0 python-version: "3.10"