diff --git a/skore/src/skore/sklearn/train_test_split/train_test_split.py b/skore/src/skore/sklearn/train_test_split/train_test_split.py index 48b54bb54..267cdad35 100644 --- a/skore/src/skore/sklearn/train_test_split/train_test_split.py +++ b/skore/src/skore/sklearn/train_test_split/train_test_split.py @@ -7,6 +7,7 @@ import numpy as np from numpy.random import RandomState +from rich.panel import Panel from skore.sklearn.find_ml_task import _find_ml_task from skore.sklearn.train_test_split.warning import TRAIN_TEST_SPLIT_WARNINGS @@ -158,10 +159,24 @@ class labels. ml_task=ml_task, ) + from skore import console # avoid circular import + for warning_class in TRAIN_TEST_SPLIT_WARNINGS: warning = warning_class.check(**kwargs) - if warning is not None: - warnings.warn(message=warning, category=warning_class, stacklevel=1) + if warning is not None and ( + not warnings.filters + or not any( + f[0] == "ignore" and f[2] == warning_class for f in warnings.filters + ) + ): + console.print( + Panel( + title=warning_class.__name__, + renderable=warning, + style="orange1", + border_style="cyan", + ) + ) return output diff --git a/skore/tests/unit/sklearn/train_test_split/test_train_test_split.py b/skore/tests/unit/sklearn/train_test_split/test_train_test_split.py index b61aa7af3..100310342 100644 --- a/skore/tests/unit/sklearn/train_test_split/test_train_test_split.py +++ b/skore/tests/unit/sklearn/train_test_split/test_train_test_split.py @@ -129,15 +129,45 @@ def case_time_based_column_polars_dates(): case_time_based_column_polars_dates, ], ) -def test_train_test_split_warns(params): +def test_train_test_split_warns(params, capsys): """When train_test_split is called with these args and kwargs, the corresponding - warning should fire.""" - warnings.simplefilter("ignore") + warning should be printed to the console.""" + args, kwargs, warning_cls = params() + + train_test_split(*args, **kwargs) + + captured = capsys.readouterr() + assert warning_cls.__name__ in captured.out + + +@pytest.mark.parametrize( + "params", + [ + case_high_class_imbalance, + case_high_class_imbalance_too_few_examples, + case_high_class_imbalance_too_few_examples_kwargs, + case_high_class_imbalance_too_few_examples_kwargs_mixed, + case_stratify, + case_random_state_unset, + case_shuffle_true, + case_shuffle_none, + case_time_based_column, + case_time_based_columns_several, + case_time_based_column_polars, + case_time_based_column_polars_dates, + ], +) +def test_train_test_split_warns_suppressed(params, capsys): + """Verify that warnings can be suppressed and don't appear in the console output.""" args, kwargs, warning_cls = params() - with pytest.warns(warning_cls): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=warning_cls) train_test_split(*args, **kwargs) + captured = capsys.readouterr() + assert warning_cls.__name__ not in captured.out + def test_train_test_split_kwargs(): """Passing data by keyword arguments should produce the same results as passing