more refactor tests
glemaitre committed Jan 8, 2025
1 parent 11d338e commit e46ad58
Showing 1 changed file with 27 additions and 80 deletions.
107 changes: 27 additions & 80 deletions skore/tests/unit/sklearn/test_estimator.py
@@ -534,23 +534,8 @@ def _normalize_metric_name(column):
     return re.sub(r"[^a-zA-Z]", "", s.lower())
 
 
-@pytest.mark.parametrize("pos_label, nb_stats", [(None, 2), (1, 1)])
-def test_estimator_report_report_metrics_binary(
-    binary_classification_data, binary_classification_data_svc, pos_label, nb_stats
-):
-    """Check the behaviour of the `report_metrics` method with binary
-    classification. We test both with an SVC that does not support `predict_proba` and a
-    RandomForestClassifier that does.
-    """
-    estimator, X_test, y_test = binary_classification_data
-    report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
-    assert hasattr(report.metrics, "report_metrics")
-    result = report.metrics.report_metrics(pos_label=pos_label)
+def _check_results_report_metrics(result, expected_metrics, expected_nb_stats):
     assert isinstance(result, pd.DataFrame)
-    expected_metrics = ("precision", "recall", "roc_auc", "brier_score")
-    # depending on `pos_label`, we report a stats for each class or not for
-    # precision and recall
-    expected_nb_stats = 2 * nb_stats + 2
     assert len(result.columns) == expected_nb_stats
 
     normalized_expected = {
@@ -566,29 +551,32 @@ def test_estimator_report_report_metrics_binary(
             f" {expected_metrics}"
         )
 
+
+@pytest.mark.parametrize("pos_label, nb_stats", [(None, 2), (1, 1)])
+def test_estimator_report_report_metrics_binary(
+    binary_classification_data, binary_classification_data_svc, pos_label, nb_stats
+):
+    """Check the behaviour of the `report_metrics` method with binary
+    classification. We test both with an SVC that does not support `predict_proba` and a
+    RandomForestClassifier that does.
+    """
+    estimator, X_test, y_test = binary_classification_data
+    report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
+    result = report.metrics.report_metrics(pos_label=pos_label)
+    expected_metrics = ("precision", "recall", "roc_auc", "brier_score")
+    # depending on `pos_label`, we report a stats for each class or not for
+    # precision and recall
+    expected_nb_stats = 2 * nb_stats + 2
+    _check_results_report_metrics(result, expected_metrics, expected_nb_stats)
+
     estimator, X_test, y_test = binary_classification_data_svc
     report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
-    assert hasattr(report.metrics, "report_metrics")
     result = report.metrics.report_metrics(pos_label=pos_label)
-    assert isinstance(result, pd.DataFrame)
     expected_metrics = ("precision", "recall", "roc_auc")
     # depending on `pos_label`, we report a stats for each class or not for
     # precision and recall
     expected_nb_stats = 2 * nb_stats + 1
-    assert len(result.columns) == expected_nb_stats
-
-    normalized_expected = {
-        _normalize_metric_name(metric) for metric in expected_metrics
-    }
-    for column in result.columns:
-        normalized_column = _normalize_metric_name(column)
-        matches = [
-            metric for metric in normalized_expected if metric == normalized_column
-        ]
-        assert len(matches) == 1, (
-            f"No match found for column '{column}' in expected metrics: "
-            f" {expected_metrics}"
-        )
+    _check_results_report_metrics(result, expected_metrics, expected_nb_stats)
 
 
 def test_estimator_report_report_metrics_multiclass(
@@ -599,71 +587,30 @@ def test_estimator_report_report_metrics_multiclass(
     """
     estimator, X_test, y_test = multiclass_classification_data
     report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
-    assert hasattr(report.metrics, "report_metrics")
     result = report.metrics.report_metrics()
-    assert isinstance(result, pd.DataFrame)
     expected_metrics = ("precision", "recall", "roc_auc", "log_loss")
     # since we are not averaging by default, we report 3 statistics for
     # precision, recall and roc_auc
-    assert len(result.columns) == 10
-
-    normalized_expected = {
-        _normalize_metric_name(metric) for metric in expected_metrics
-    }
-    for column in result.columns:
-        normalized_column = _normalize_metric_name(column)
-        matches = [
-            metric for metric in normalized_expected if metric == normalized_column
-        ]
-        assert len(matches) == 1, (
-            f"No match found for column '{column}' in expected metrics: "
-            f" {expected_metrics}"
-        )
+    expected_nb_stats = 3 * 3 + 1
+    _check_results_report_metrics(result, expected_metrics, expected_nb_stats)
 
     estimator, X_test, y_test = multiclass_classification_data_svc
     report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
-    assert hasattr(report.metrics, "report_metrics")
     result = report.metrics.report_metrics()
-    assert isinstance(result, pd.DataFrame)
     expected_metrics = ("precision", "recall")
-    assert len(result.columns) == 6
-
-    normalized_expected = {
-        _normalize_metric_name(metric) for metric in expected_metrics
-    }
-    for column in result.columns:
-        normalized_column = _normalize_metric_name(column)
-        matches = [
-            metric for metric in normalized_expected if metric == normalized_column
-        ]
-        assert len(matches) == 1, (
-            f"No match found for column '{column}' in expected metrics: "
-            f" {expected_metrics}"
-        )
+    # since we are not averaging by default, we report 3 statistics for
+    # precision and recall
+    expected_nb_stats = 3 * 2
+    _check_results_report_metrics(result, expected_metrics, expected_nb_stats)
 
 
 def test_estimator_report_report_metrics_regression(regression_data):
     """Check the behaviour of the `report_metrics` method with regression."""
     estimator, X_test, y_test = regression_data
     report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
-    assert hasattr(report.metrics, "report_metrics")
     result = report.metrics.report_metrics()
-    assert isinstance(result, pd.DataFrame)
     expected_metrics = ("r2", "rmse")
-    assert len(result.columns) == len(expected_metrics)
-
-    normalized_expected = {
-        _normalize_metric_name(metric) for metric in expected_metrics
-    }
-    for column in result.columns:
-        normalized_column = _normalize_metric_name(column)
-        matches = [
-            metric for metric in normalized_expected if metric == normalized_column
-        ]
-        assert len(matches) == 1, (
-            f"No match found for column '{column}' in expected metrics: "
-            f" {expected_metrics}"
-        )
+    _check_results_report_metrics(result, expected_metrics, len(expected_metrics))
 
 
 def test_estimator_report_report_metrics_scoring_kwargs(
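For readers skimming the diff, here is a minimal, self-contained sketch of the helper pattern the commit converges on. The body of _normalize_metric_name is simplified (only its return line appears in the hunks above, so tuple/MultiIndex handling is omitted), and the sample column names and values below are invented for illustration; they are not skore's real report_metrics output.

import re

import pandas as pd


def _normalize_metric_name(column):
    # Simplified stand-in: lowercase the label and keep letters only, so
    # "roc_auc" and "ROC AUC" compare equal. The real helper may also
    # flatten tuple column labels before this step.
    return re.sub(r"[^a-zA-Z]", "", str(column).lower())


def _check_results_report_metrics(result, expected_metrics, expected_nb_stats):
    # Shared assertions extracted by the commit: check the frame's shape
    # first, then verify every column maps to exactly one expected metric.
    assert isinstance(result, pd.DataFrame)
    assert len(result.columns) == expected_nb_stats

    normalized_expected = {
        _normalize_metric_name(metric) for metric in expected_metrics
    }
    for column in result.columns:
        normalized_column = _normalize_metric_name(column)
        matches = [
            metric for metric in normalized_expected if metric == normalized_column
        ]
        assert len(matches) == 1, (
            f"No match found for column '{column}' in expected metrics: "
            f" {expected_metrics}"
        )


# Invented stand-in for report.metrics.report_metrics(pos_label=None) on a
# binary problem: precision and recall reported once per class, plus
# roc_auc and brier_score, i.e. 2 * 2 + 2 = 6 columns, matching the
# test's expected_nb_stats arithmetic.
result = pd.DataFrame(
    [[0.90, 0.80, 0.85, 0.75, 0.95, 0.10]],
    columns=["precision_0", "precision_1", "recall_0", "recall_1",
             "roc_auc", "brier_score"],
)
_check_results_report_metrics(
    result,
    expected_metrics=("precision", "recall", "roc_auc", "brier_score"),
    expected_nb_stats=2 * 2 + 2,
)

Run as a plain script, this exits silently; a mismatched column name or count would trip the corresponding assertion with the same message the tests use.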
