diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a231258f..ef9e1777 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,10 +19,16 @@ jobs: # as well as selected previous versions on # https://pytorch.org/get-started/previous-versions/ torch-version: ["2.2.2", "2.4.0"] + sklearn-version: ["latest"] include: - os: windows-latest torch-version: 2.4.0 python-version: "3.10" + sklearn-version: "latest" + - os: ubuntu-latest + torch-version: 2.4.0 + python-version: "3.10" + sklearn-version: "legacy" runs-on: ${{ matrix.os }} @@ -32,7 +38,7 @@ jobs: uses: actions/cache@v3 with: path: ~/.cache/pip - key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }} + key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}-sklearn_${{ matrix.sklearn-version }} - name: Checkout code uses: actions/checkout@v2 @@ -48,6 +54,11 @@ jobs: python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu pip install '.[dev,datasets,integrations]' + - name: Check sklearn legacy version + if: matrix.sklearn-version == 'legacy' + run: | + pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]' + - name: Run the formatter run: | make format diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 046d3344..9a74eeb6 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -30,8 +30,10 @@ import pkg_resources import sklearn.utils.validation as sklearn_utils_validation import torch +import sklearn from sklearn.base import BaseEstimator from sklearn.base import TransformerMixin +from sklearn.utils.metaestimators import available_if from torch import nn import cebra.data @@ -41,6 +43,11 @@ import cebra.models import cebra.solver +def check_version(estimator): + # NOTE(stes): required as a check for the old way of specifying tags + # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165 + from packaging import version + return version.parse(sklearn.__version__) < version.parse("1.6.dev") def _init_loader( is_cont: bool, @@ -364,7 +371,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": return cebra_ -class CEBRA(BaseEstimator, TransformerMixin): +class CEBRA(TransformerMixin, BaseEstimator): """CEBRA model defined as part of a ``scikit-learn``-like API. Attributes: @@ -1294,6 +1301,15 @@ def fit_transform( callback_frequency=callback_frequency) return self.transform(X) + def __sklearn_tags__(self): + # NOTE(stes): from 1.6.dev, this is the new way to specify tags + # https://scikit-learn.org/dev/developers/develop.html + # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165 + tags = super().__sklearn_tags__() + tags.non_deterministic = True + return tags + + @available_if(check_version) def _more_tags(self): # NOTE(stes): This tag is needed as seeding is not fully implemented in the # current version of CEBRA.