Skip to content

Commit

Permalink
feat: Use rich Panel for showing warning in train_test_split (#1086)
Browse files Browse the repository at this point in the history
closes #1060 

It is the alternative to #1060 using `rich`. I added a test to check
that we can filter the warning since we are not using the usual
`warnings` module.

In the future, we could factor out the code in a utils to be sure that
we can also transform the warnings into error.
  • Loading branch information
glemaitre authored Jan 10, 2025
1 parent b177651 commit 6f6fafb
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 12 deletions.
32 changes: 24 additions & 8 deletions skore/src/skore/sklearn/train_test_split/train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
from numpy.random import RandomState
from rich.panel import Panel

from skore.sklearn.find_ml_task import _find_ml_task
from skore.sklearn.train_test_split.warning import TRAIN_TEST_SPLIT_WARNINGS
Expand Down Expand Up @@ -88,26 +89,27 @@ class labels.
>>> X, y = np.arange(10).reshape((5, 2)), range(5)
>>> # Drop-in replacement for sklearn train_test_split
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, # doctest: +SKIP
... test_size=0.33, random_state=42)
>>> X_train
>>> X_train # doctest: +SKIP
array([[4, 5],
[0, 1],
[6, 7]])
>>> # Explicit X and y, makes detection of problems easier
>>> X_train, X_test, y_train, y_test = train_test_split(X=X, y=y,
>>> X_train, X_test, y_train, y_test = train_test_split(X=X, y=y, # doctest: +SKIP
... test_size=0.33, random_state=42)
>>> X_train
>>> X_train # doctest: +SKIP
array([[4, 5],
[0, 1],
[6, 7]])
>>> # When passing X and y explicitly, X is returned before y
>>> arr = np.arange(10).reshape((5, 2))
>>> arr_train, arr_test, X_train, X_test, y_train, y_test = train_test_split(
>>> splits = train_test_split( # doctest: +SKIP
... arr, y=y, X=X, test_size=0.33, random_state=42)
>>> X_train
>>> arr_train, arr_test, X_train, X_test, y_train, y_test = splits # doctest: +SKIP
>>> X_train # doctest: +SKIP
array([[4, 5],
[0, 1],
[6, 7]])
Expand Down Expand Up @@ -158,10 +160,24 @@ class labels.
ml_task=ml_task,
)

from skore import console # avoid circular import

for warning_class in TRAIN_TEST_SPLIT_WARNINGS:
warning = warning_class.check(**kwargs)

if warning is not None:
warnings.warn(message=warning, category=warning_class, stacklevel=1)
if warning is not None and (
not warnings.filters
or not any(
f[0] == "ignore" and f[2] == warning_class for f in warnings.filters
)
):
console.print(
Panel(
title=warning_class.__name__,
renderable=warning,
style="orange1",
border_style="cyan",
)
)

return output
38 changes: 34 additions & 4 deletions skore/tests/unit/sklearn/train_test_split/test_train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,45 @@ def case_time_based_column_polars_dates():
case_time_based_column_polars_dates,
],
)
def test_train_test_split_warns(params):
def test_train_test_split_warns(params, capsys):
"""When train_test_split is called with these args and kwargs, the corresponding
warning should fire."""
warnings.simplefilter("ignore")
warning should be printed to the console."""
args, kwargs, warning_cls = params()

train_test_split(*args, **kwargs)

captured = capsys.readouterr()
assert warning_cls.__name__ in captured.out


@pytest.mark.parametrize(
"params",
[
case_high_class_imbalance,
case_high_class_imbalance_too_few_examples,
case_high_class_imbalance_too_few_examples_kwargs,
case_high_class_imbalance_too_few_examples_kwargs_mixed,
case_stratify,
case_random_state_unset,
case_shuffle_true,
case_shuffle_none,
case_time_based_column,
case_time_based_columns_several,
case_time_based_column_polars,
case_time_based_column_polars_dates,
],
)
def test_train_test_split_warns_suppressed(params, capsys):
"""Verify that warnings can be suppressed and don't appear in the console output."""
args, kwargs, warning_cls = params()

with pytest.warns(warning_cls):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=warning_cls)
train_test_split(*args, **kwargs)

captured = capsys.readouterr()
assert warning_cls.__name__ not in captured.out


def test_train_test_split_kwargs():
"""Passing data by keyword arguments should produce the same results as passing
Expand Down

0 comments on commit 6f6fafb

Please sign in to comment.