Skip to content

Commit

Permalink
feat: Use rich Panel for showing warning in train_test_split
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Jan 10, 2025
1 parent 1a4151a commit a32c62d
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 6 deletions.
19 changes: 17 additions & 2 deletions skore/src/skore/sklearn/train_test_split/train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
from numpy.random import RandomState
from rich.panel import Panel

from skore.sklearn.find_ml_task import _find_ml_task
from skore.sklearn.train_test_split.warning import TRAIN_TEST_SPLIT_WARNINGS
Expand Down Expand Up @@ -158,10 +159,24 @@ class labels.
ml_task=ml_task,
)

from skore import console # avoid circular import

for warning_class in TRAIN_TEST_SPLIT_WARNINGS:
warning = warning_class.check(**kwargs)

if warning is not None:
warnings.warn(message=warning, category=warning_class, stacklevel=1)
if warning is not None and (
not warnings.filters
or not any(
f[0] == "ignore" and f[2] == warning_class for f in warnings.filters
)
):
console.print(
Panel(
title=warning_class.__name__,
renderable=warning,
style="orange1",
border_style="cyan",
)
)

return output
38 changes: 34 additions & 4 deletions skore/tests/unit/sklearn/train_test_split/test_train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,45 @@ def case_time_based_column_polars_dates():
case_time_based_column_polars_dates,
],
)
def test_train_test_split_warns(params):
def test_train_test_split_warns(params, capsys):
"""When train_test_split is called with these args and kwargs, the corresponding
warning should fire."""
warnings.simplefilter("ignore")
warning should be printed to the console."""
args, kwargs, warning_cls = params()

train_test_split(*args, **kwargs)

captured = capsys.readouterr()
assert warning_cls.__name__ in captured.out


@pytest.mark.parametrize(
"params",
[
case_high_class_imbalance,
case_high_class_imbalance_too_few_examples,
case_high_class_imbalance_too_few_examples_kwargs,
case_high_class_imbalance_too_few_examples_kwargs_mixed,
case_stratify,
case_random_state_unset,
case_shuffle_true,
case_shuffle_none,
case_time_based_column,
case_time_based_columns_several,
case_time_based_column_polars,
case_time_based_column_polars_dates,
],
)
def test_train_test_split_warns_suppressed(params, capsys):
"""Verify that warnings can be suppressed and don't appear in the console output."""
args, kwargs, warning_cls = params()

with pytest.warns(warning_cls):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=warning_cls)
train_test_split(*args, **kwargs)

captured = capsys.readouterr()
assert warning_cls.__name__ not in captured.out


def test_train_test_split_kwargs():
"""Passing data by keyword arguments should produce the same results as passing
Expand Down

0 comments on commit a32c62d

Please sign in to comment.