Skip to content

Commit

Permalink
fix: Do not compute Brier score when predict_proba is not available (#…
Browse files Browse the repository at this point in the history
…1064)

closes #1050 

As specified, do not try to compute the Brier score with an estimator that
does not provide probability estimates (i.e. one that does not have a
`predict_proba` method).

Added a non-regression test as well.
  • Loading branch information
glemaitre authored Jan 9, 2025
1 parent 29d3689 commit 3c05371
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,27 @@ def _get_scorers_to_add(estimator, y) -> dict[str, Any]:
),
}
if ml_task == "binary-classification":
return {
"roc_auc": "roc_auc",
"brier_score_loss": metrics.make_scorer(
metrics.brier_score_loss, response_method="predict_proba"
),
scorers_to_add = {
"recall": "recall",
"precision": "precision",
"roc_auc": "roc_auc",
}
if ml_task == "multiclass-classification":
if hasattr(estimator, "predict_proba"):
return {
"recall_weighted": "recall_weighted",
"precision_weighted": "precision_weighted",
"roc_auc_ovr_weighted": "roc_auc_ovr_weighted",
"log_loss": metrics.make_scorer(
metrics.log_loss, response_method="predict_proba"
),
}
return {
scorers_to_add["brier_score_loss"] = metrics.make_scorer(
metrics.brier_score_loss, response_method="predict_proba"
)
return scorers_to_add
if ml_task == "multiclass-classification":
scorers_to_add = {
"recall_weighted": "recall_weighted",
"precision_weighted": "precision_weighted",
}
if hasattr(estimator, "predict_proba"):
scorers_to_add["roc_auc_ovr_weighted"] = "roc_auc_ovr_weighted"
scorers_to_add["log_loss"] = metrics.make_scorer(
metrics.log_loss, response_method="predict_proba"
)
return scorers_to_add
return {}


Expand Down Expand Up @@ -104,9 +103,11 @@ def _add_scorers(scorers, scorers_to_add):

internal_scorer = _MultimetricScorer(
scorers={
name: check_scoring(estimator=None, scoring=scoring)
if isinstance(scoring, str)
else scoring
name: (
check_scoring(estimator=None, scoring=scoring)
if isinstance(scoring, str)
else scoring
)
for name, scoring in scorers_to_add.items()
}
)
Expand Down
61 changes: 61 additions & 0 deletions skore/tests/unit/sklearn/test_cross_validate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import pytest
from sklearn.datasets import make_classification, make_regression
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.svm import SVC
from skore.sklearn.cross_validation import CrossValidationReporter
from skore.sklearn.cross_validation.cross_validation_helpers import _get_scorers_to_add


def prepare_cv():
Expand Down Expand Up @@ -35,3 +40,59 @@ def test_cross_validate_return_estimator():
assert "indices" in reporter.cv_results
assert "estimator" in reporter._cv_results
assert "indices" in reporter._cv_results


@pytest.mark.parametrize(
    "estimator,dataset_func,dataset_kwargs,expected_keys",
    [
        pytest.param(
            LinearRegression(),
            make_regression,
            {"n_targets": 1},
            {"r2", "root_mean_squared_error"},
            id="regression",
        ),
        pytest.param(
            LogisticRegression(),
            make_classification,
            {"n_classes": 2},
            {"recall", "precision", "roc_auc", "brier_score_loss"},
            id="binary_classification_with_proba",
        ),
        pytest.param(
            SVC(probability=False),
            make_classification,
            {"n_classes": 2},
            {"recall", "precision", "roc_auc"},
            id="binary_classification_without_proba",
        ),
        pytest.param(
            LogisticRegression(),
            make_classification,
            {"n_classes": 3, "n_clusters_per_class": 1},
            {
                "recall_weighted",
                "precision_weighted",
                "roc_auc_ovr_weighted",
                "log_loss",
            },
            id="multiclass_with_proba",
        ),
        pytest.param(
            SVC(probability=False),
            make_classification,
            {"n_classes": 3, "n_clusters_per_class": 1},
            {"recall_weighted", "precision_weighted"},
            id="multiclass_without_proba",
        ),
    ],
)
def test_get_scorers_to_add(estimator, dataset_func, dataset_kwargs, expected_keys):
    """Verify which scorer names are selected for each estimator / task pair.

    Each case builds a synthetic dataset with ``dataset_func``, hands only its
    target to ``_get_scorers_to_add`` together with the (unfitted) estimator,
    and compares the names of the returned scorers with ``expected_keys``.
    The ``*_without_proba`` cases expect the probability-based metrics to be
    left out.

    Non-regression test for:
    https://github.com/probabl-ai/skore/issues/1050
    """
    # The feature matrix is irrelevant here: only the target is passed on.
    _, target = dataset_func(**dataset_kwargs)
    selected = _get_scorers_to_add(estimator, target)
    # Iterating a mapping yields its keys, so this collects the scorer names.
    assert set(selected) == expected_keys

0 comments on commit 3c05371

Please sign in to comment.