From 297ee92476db9847f400834c04114a01fddc011d Mon Sep 17 00:00:00 2001
From: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com>
Date: Fri, 29 Nov 2024 13:42:44 +0100
Subject: [PATCH] Add goodness of fit metric

---
 cebra/integrations/sklearn/metrics.py | 62 +++++++++++++++++++++++++++
 tests/test_sklearn_metrics.py         | 51 ++++++++++++++++++++++
 2 files changed, 113 insertions(+)

diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py
index 9712d021..b9be33ee 100644
--- a/cebra/integrations/sklearn/metrics.py
+++ b/cebra/integrations/sklearn/metrics.py
@@ -108,6 +108,68 @@ def infonce_loss(
     return avg_loss
 
 
+def goodness_of_fit(model: cebra_sklearn_cebra.CEBRA) -> List[float]:
+    """Evaluate the goodness of fit (in bits) for a given model.
+
+    The goodness of fit is computed offline, from the loss recorded
+    during training and the batch size stored on the model. It normalizes
+    the loss with respect to the batch size, so that models trained with
+    different batch sizes or different implementations can be compared.
+
+    Args:
+        model: The fitted CEBRA model to evaluate. It must have been
+            trained with an explicit batch size (``batch_size`` must
+            not be ``None``).
+
+    Returns:
+        A list of float values representing the goodness of fit across training.
+    """
+
+    if isinstance(model, cebra_sklearn_cebra.CEBRA):
+        if model.batch_size is None:
+            raise NotImplementedError(
+                "Batch size is None, please provide a model with a batch size to compute the goodness of fit."
+            )
+        if model.solver_name_ == 'single-session':
+            gof = _goodness_of_fit(loss=model.state_dict_["loss"],
+                                   batch_size=model.batch_size)
+        elif model.solver_name_ == 'multi-session':
+            # For the multisession implementation, the batch size is multiplied
+            # by the number of sessions to get a correct comparison.
+            gof = _goodness_of_fit(loss=model.state_dict_["loss"],
+                                   batch_size=model.batch_size *
+                                   model.num_sessions_)
+        else:
+            raise NotImplementedError(f"Invalid solver: {model.solver_name_}.")
+    elif isinstance(model, list):
+        raise ValueError(
+            "Model should correspond to a single CEBRA model, "
+            f"got {type(model)}, containing {len(model)} elements.")
+    else:
+        raise ValueError(f"Provide a CEBRA model, got {type(model)}.")
+    return gof
+
+
+def _goodness_of_fit(loss: List[float], batch_size: int) -> List[float]:
+    """Compute the goodness of fit (in bits) from a recorded loss curve.
+
+    The loss is normalized with respect to the batch size, so that models
+    trained with different batch sizes or different implementations can
+    be compared.
+
+    Args:
+        loss: A list of length ``max_iterations``, containing the loss
+            recorded at each training step.
+        batch_size: Batch size used to train the model. For the multisession
+            implementation, multiply it by the number of sessions.
+
+    Returns:
+        A list of floats, one goodness of fit value per entry in ``loss``.
+ """ + log_batch_size = np.log(batch_size) + return [(1 / np.log(2)) * (log_batch_size - lb) for lb in loss] + + def _consistency_scores( embeddings: List[Union[npt.NDArray, torch.Tensor]], datasets: List[Union[int, str]], diff --git a/tests/test_sklearn_metrics.py b/tests/test_sklearn_metrics.py index 58e12010..f46597ed 100644 --- a/tests/test_sklearn_metrics.py +++ b/tests/test_sklearn_metrics.py @@ -223,6 +223,57 @@ def test_sklearn_infonce_loss(): ) +def test_sklearn_goodness_of_fit(): + max_loss_iterations = 2 + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture="offset10-model", + max_iterations=5, + batch_size=128, + ) + + # Example data + X = torch.tensor(np.random.uniform(0, 1, (1000, 50))) + y_c1 = torch.tensor(np.random.uniform(0, 1, (1000, 5))) + + X2 = torch.tensor(np.random.uniform(0, 1, (500, 20))) + y2_c1 = torch.tensor(np.random.uniform(0, 1, (500, 5))) + + # Single session + cebra_model.fit(X, y_c1) + + gof = cebra.sklearn.metrics.goodness_of_fit(cebra_model) + assert isinstance(gof, list) + _gof = cebra.sklearn.metrics._goodness_of_fit( + cebra_model.state_dict_["loss"], batch_size=128) + assert isinstance(_gof, list) + assert gof == _gof + + # Multisession + cebra_model.fit([X, X2], [y_c1, y2_c1]) + + gof = cebra.sklearn.metrics.goodness_of_fit(cebra_model) + assert isinstance(gof, list) + _gof = cebra.sklearn.metrics._goodness_of_fit( + cebra_model.state_dict_["loss"], batch_size=128 * 2) + assert isinstance(_gof, list) + assert gof == _gof + + # Multiple models passed + with pytest.raises(ValueError, match="single.*model"): + _ = cebra.sklearn.metrics.goodness_of_fit([cebra_model, cebra_model]) + + # No batch size + cebra_model_no_bs = cebra_sklearn_cebra.CEBRA( + model_architecture="offset10-model", + max_iterations=max_loss_iterations, + batch_size=None, + ) + + cebra_model_no_bs.fit(X) + with pytest.raises(NotImplementedError, match="Batch.*size"): + gof = cebra.sklearn.metrics.goodness_of_fit(cebra_model_no_bs) + + def test_sklearn_datasets_consistency(): # Example data np.random.seed(42)