From 3c05371fffa0692288b531738d849e7ab0ab64c8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 9 Jan 2025 10:11:04 +0100 Subject: [PATCH] fix: Do not compute Brier score when predict_proba is not available (#1064) closes #1050 As specified, do not try to compute Brier score with an estimator that does not provide probability estimates (i.e. does not have a `predict_proba` method). Added a non-regression test as well. --- .../cross_validation_helpers.py | 37 +++++------ .../tests/unit/sklearn/test_cross_validate.py | 61 +++++++++++++++++++ 2 files changed, 80 insertions(+), 18 deletions(-) diff --git a/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py b/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py index 97a46d61e..49d17452c 100644 --- a/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py +++ b/skore/src/skore/sklearn/cross_validation/cross_validation_helpers.py @@ -33,28 +33,27 @@ def _get_scorers_to_add(estimator, y) -> dict[str, Any]: ), } if ml_task == "binary-classification": - return { - "roc_auc": "roc_auc", - "brier_score_loss": metrics.make_scorer( - metrics.brier_score_loss, response_method="predict_proba" - ), + scorers_to_add = { "recall": "recall", "precision": "precision", + "roc_auc": "roc_auc", } - if ml_task == "multiclass-classification": if hasattr(estimator, "predict_proba"): - return { - "recall_weighted": "recall_weighted", - "precision_weighted": "precision_weighted", - "roc_auc_ovr_weighted": "roc_auc_ovr_weighted", - "log_loss": metrics.make_scorer( - metrics.log_loss, response_method="predict_proba" - ), - } - return { + scorers_to_add["brier_score_loss"] = metrics.make_scorer( + metrics.brier_score_loss, response_method="predict_proba" + ) + return scorers_to_add + if ml_task == "multiclass-classification": + scorers_to_add = { "recall_weighted": "recall_weighted", "precision_weighted": "precision_weighted", } + if hasattr(estimator, "predict_proba"): + 
scorers_to_add["roc_auc_ovr_weighted"] = "roc_auc_ovr_weighted" + scorers_to_add["log_loss"] = metrics.make_scorer( + metrics.log_loss, response_method="predict_proba" + ) + return scorers_to_add return {} @@ -104,9 +103,11 @@ def _add_scorers(scorers, scorers_to_add): internal_scorer = _MultimetricScorer( scorers={ - name: check_scoring(estimator=None, scoring=scoring) - if isinstance(scoring, str) - else scoring + name: ( + check_scoring(estimator=None, scoring=scoring) + if isinstance(scoring, str) + else scoring + ) for name, scoring in scorers_to_add.items() } ) diff --git a/skore/tests/unit/sklearn/test_cross_validate.py b/skore/tests/unit/sklearn/test_cross_validate.py index 1c43b574e..6739d31fa 100644 --- a/skore/tests/unit/sklearn/test_cross_validate.py +++ b/skore/tests/unit/sklearn/test_cross_validate.py @@ -1,4 +1,9 @@ +import pytest +from sklearn.datasets import make_classification, make_regression +from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.svm import SVC from skore.sklearn.cross_validation import CrossValidationReporter +from skore.sklearn.cross_validation.cross_validation_helpers import _get_scorers_to_add def prepare_cv(): @@ -35,3 +40,59 @@ def test_cross_validate_return_estimator(): assert "indices" in reporter.cv_results assert "estimator" in reporter._cv_results assert "indices" in reporter._cv_results + + +@pytest.mark.parametrize( + "estimator,dataset_func,dataset_kwargs,expected_keys", + [ + pytest.param( + LinearRegression(), + make_regression, + {"n_targets": 1}, + {"r2", "root_mean_squared_error"}, + id="regression", + ), + pytest.param( + LogisticRegression(), + make_classification, + {"n_classes": 2}, + {"recall", "precision", "roc_auc", "brier_score_loss"}, + id="binary_classification_with_proba", + ), + pytest.param( + SVC(probability=False), + make_classification, + {"n_classes": 2}, + {"recall", "precision", "roc_auc"}, + id="binary_classification_without_proba", + ), + pytest.param( + 
LogisticRegression(), + make_classification, + {"n_classes": 3, "n_clusters_per_class": 1}, + { + "recall_weighted", + "precision_weighted", + "roc_auc_ovr_weighted", + "log_loss", + }, + id="multiclass_with_proba", + ), + pytest.param( + SVC(probability=False), + make_classification, + {"n_classes": 3, "n_clusters_per_class": 1}, + {"recall_weighted", "precision_weighted"}, + id="multiclass_without_proba", + ), + ], +) +def test_get_scorers_to_add(estimator, dataset_func, dataset_kwargs, expected_keys): + """Check that the scorers to add are correct. + + Non-regression test for: + https://github.com/probabl-ai/skore/issues/1050 + """ + X, y = dataset_func(**dataset_kwargs) + scorers = _get_scorers_to_add(estimator, y) + assert set(scorers.keys()) == expected_keys