From f2e6257823d625d317e011d7f32ea1f97c0ed867 Mon Sep 17 00:00:00 2001
From: Steffen Schneider
Date: Mon, 16 Dec 2024 12:12:51 -0500
Subject: [PATCH] Handle batch size = None for goodness of fit computation

---
 cebra/integrations/sklearn/metrics.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py
index 29f7715b..9a1dd5a6 100644
--- a/cebra/integrations/sklearn/metrics.py
+++ b/cebra/integrations/sklearn/metrics.py
@@ -138,7 +138,7 @@ def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
         >>> import cebra
         >>> import numpy as np
         >>> neural_data = np.random.uniform(0, 1, (1000, 20))
-        >>> cebra_model = cebra.CEBRA(max_iterations=10)
+        >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
         >>> cebra_model.fit(neural_data)
         CEBRA(max_iterations=10)
         >>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
@@ -169,7 +169,7 @@ def goodness_of_fit_history(model):
         >>> import cebra
         >>> import numpy as np
         >>> neural_data = np.random.uniform(0, 1, (1000, 20))
-        >>> cebra_model = cebra.CEBRA(max_iterations=10)
+        >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
         >>> cebra_model.fit(neural_data)
         CEBRA(max_iterations=10)
         >>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
@@ -210,6 +210,11 @@ def infonce_to_goodness_of_fit(infonce: Union[float, Iterable[float]],
     """
     if not hasattr(model, "state_dict_"):
         raise RuntimeError("Fit the CEBRA model first.")
+    if model.batch_size is None:
+        raise ValueError(
+            "Computing the goodness of fit is not yet supported for "
+            "models trained on the full dataset (batch_size=None)."
+        )
     nats_to_bits = np.log2(np.e)
     num_sessions = model.num_sessions_
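
Not part of the patch: a minimal sketch of the behaviour this change introduces, assuming the
default batch_size=None is kept when fitting, so that the new guard in infonce_to_goodness_of_fit
is reached when the goodness of fit is requested.

    import numpy as np
    import cebra

    neural_data = np.random.uniform(0, 1, (1000, 20))

    # batch_size defaults to None, i.e. the model trains on the full dataset
    cebra_model = cebra.CEBRA(max_iterations=10)
    cebra_model.fit(neural_data)

    try:
        cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
    except ValueError as error:
        # With this patch, a clear error is raised instead of an undefined score
        print(error)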