Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for new __sklearn_tags__ #205

Merged
merged 4 commits into from
Dec 16, 2024
Merged

Add support for new __sklearn_tags__ #205

merged 4 commits into from
Dec 16, 2024

Conversation

stes
Copy link
Member

@stes stes commented Dec 16, 2024

Fix #204

sklearn 1.6.0 was released on Dec 9, 24 and introduced a new mechanism for specifying estimator tags (https://scikit-learn.org/dev/developers/develop.html). This PR adopts CEBRA to comply with this new notation. Older sklearn variants will fall back to the more_tags() functions as recommended in this comment.

Indepedently, I spotted a bug in the inheritance order in the CEBRA class, which was fixed now, as described here.

Finally, since the code is now version dependent and there might be users rolling older sklearn version, I extended the test suite by one case checking with a legacy sklearn version (version 1.4.2 which is roughly one year old) -- this will hopefully cover the most important cases. The majority of tests are with sklearn latest (1.6.0 as of Dec 16, 24).

@stes stes requested a review from MMathisLab December 16, 2024 17:57
Copy link
Member

@MMathisLab MMathisLab left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lgtm; but did not directly test

@stes stes merged commit 5f46c32 into main Dec 16, 2024
13 checks passed
@stes stes deleted the stes/fix-sklearn-tags branch December 16, 2024 19:32
@stes stes mentioned this pull request Dec 20, 2024
2 tasks
@Gunnar-Stunnar
Copy link

Was this ever fixed, I am using Cebra v0.4.0 and now receiving this error:

[<ipython-input-13-2e6de4decd48>](https://localhost:8080/#) in train(self, neural_session, continous_sessions)
     44 
     45         # fit decoder
---> 46         emb = self.cebra_posOnly_model.transform(nStack_train)
     47         fullKin = trainAllKin(emb, cStack_train)
     48 

[/usr/local/lib/python3.10/dist-packages/sklearn/utils/_set_output.py](https://localhost:8080/#) in wrapped(self, X, *args, **kwargs)
    317     @wraps(f)
    318     def wrapped(self, X, *args, **kwargs):
--> 319         data_to_wrap = f(self, X, *args, **kwargs)
    320         if isinstance(data_to_wrap, tuple):
    321             # only wrap the first output for cross decomposition

[/usr/local/lib/python3.10/dist-packages/cebra/integrations/sklearn/cebra.py](https://localhost:8080/#) in transform(self, X, session_id)
   1224         """
   1225 
-> 1226         sklearn_utils_validation.check_is_fitted(self, "n_features_")
   1227         model, offset = self._select_model(X, session_id)
   1228 

[/usr/local/lib/python3.10/dist-packages/sklearn/utils/validation.py](https://localhost:8080/#) in check_is_fitted(estimator, attributes, msg, all_or_any)
   1749         raise TypeError("%s is not an estimator instance." % (estimator))
   1750 
-> 1751     tags = get_tags(estimator)
   1752 
   1753     if not tags.requires_fit and attributes is None:

[/usr/local/lib/python3.10/dist-packages/sklearn/utils/_tags.py](https://localhost:8080/#) in get_tags(estimator)
    403         for klass in reversed(type(estimator).mro()):
    404             if "__sklearn_tags__" in vars(klass):
--> 405                 sklearn_tags_provider[klass] = klass.__sklearn_tags__(estimator)  # type: ignore[attr-defined]
    406                 class_order.append(klass)
    407             elif "_more_tags" in vars(klass):

[/usr/local/lib/python3.10/dist-packages/sklearn/base.py](https://localhost:8080/#) in __sklearn_tags__(self)
    857 
    858     def __sklearn_tags__(self):
--> 859         tags = super().__sklearn_tags__()
    860         tags.transformer_tags = TransformerTags()
    861         return tags

AttributeError: 'super' object has no attribute '__sklearn_tags__'

This is after installing the latest Cebra package with the scikit-learn v1.6.0.

@Gunnar-Stunnar
Copy link

rolling back scikit-learn back to v1.5.2 worked

@stes
Copy link
Member Author

stes commented Dec 26, 2024

Hi @Gunnar-Stunnar , this was merged after the cebra 0.4.0 release. If you install the latest version from git,

pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git

the error should disappear even with sklearn > 1.6.0. In case you give that a try, please let me know if it works!

@Gunnar-Stunnar
Copy link

Looks like it worked!

My logs are now being filed with this error:

/usr/local/lib/python3.10/dist-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
  warnings.warn(
pos: -9.4474 neg:  15.6202 total:  6.1728 temperature:  0.1000: 100%|██████████| 1000/1000 [00:16<00:00, 61.28it/s]
/usr/local/lib/python3.10/dist-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
  warnings.warn(

@juliagorman
Copy link

I recently installed CEBRA and am still gettitng the first error. I rolled back to scikit-learn back to v1.5.2 and now am getting this log error described above

@MMathisLab
Copy link
Member

You would need to pull from git, but a new version is coming soon!

@juliagorman
Copy link

if i re-install from git, will it change the PyTorch version I currently have installed in my conda environment? Should I just make a new environment to install from git?

@stes
Copy link
Member Author

stes commented Jan 29, 2025

I recently installed CEBRA and am still gettitng the first error. I rolled back to scikit-learn back to v1.5.2 and now am getting this log error described above

are you referring to this log output? This is just a warning message which is safe to ignore, there is no effect with respect to model fitting.

You can configure the warnings package if you want to get rid of the message.

Otherwise, as @MMathisLab noted, we will soon also release a new version of CEBRA properly handling this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Upstream sklearn change causes error ('super' object has no attribute '__sklearn_tags__') in test suite
4 participants