Skip to content

Commit

Permalink
check that we support X_y without passing original dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
glemaitre committed Jan 9, 2025
1 parent cb5f210 commit b8d4610
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 21 deletions.
23 changes: 18 additions & 5 deletions skore/src/skore/sklearn/_estimator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,28 +131,41 @@ def _get_X_y_and_data_source_hash(self, *, data_source, X=None, y=None):
The hash of the data source. None when we are able to track the data, and
thus relying on X_train, y_train, X_test, y_test.
"""
is_cluster = is_clusterer(self._parent.estimator)
if data_source == "test":
if not (X is None or y is None):
raise ValueError("X and y must be None when data_source is test.")
if self._parent._X_test is None or (
not is_cluster and self._parent._y_test is None
):
missing_data = "X_test" if is_cluster else "X_test and y_test"
raise ValueError(
f"No {data_source} data (i.e. {missing_data}) were provided "
f"when creating the reporter. Please provide the {data_source} "
"data either when creating the reporter or by setting data_source "
"to 'X_y' and providing X and y."
)
return self._parent._X_test, self._parent._y_test, None
elif data_source == "train":
if not (X is None or y is None):
raise ValueError("X and y must be None when data_source is train.")
is_cluster = is_clusterer(self._parent.estimator)
if self._parent._X_train is None or (
not is_cluster and self._parent._y_train is None
):
missing_data = "X_train" if is_cluster else "X_train and y_train"
raise ValueError(
f"No training data (i.e. {missing_data}) were provided "
"when creating the reporter. Please provide the training data."
f"No {data_source} data (i.e. {missing_data}) were provided "
f"when creating the reporter. Please provide the {data_source} "
"data either when creating the reporter or by setting data_source "
"to 'X_y' and providing X and y."
)
return self._parent._X_train, self._parent._y_train, None
elif data_source == "X_y":
is_cluster = is_clusterer(self._parent.estimator)
if X is None or (not is_cluster and y is None):
missing_data = "X" if is_cluster else "X and y"
raise ValueError(f"{missing_data} must be provided.")
raise ValueError(
f"{missing_data} must be provided when data_source is X_y."
)
return X, y, joblib.hash((X, y))
else:
raise ValueError(
Expand Down
39 changes: 23 additions & 16 deletions skore/tests/unit/sklearn/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def test_estimator_report_display_binary_classification_external_data(
when passing external data.
"""
estimator, X_test, y_test = binary_classification_data
report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
report = EstimatorReport(estimator)
assert hasattr(report.metrics.plot, display)
display_first_call = getattr(report.metrics.plot, display)(
data_source="X_y", X=X_test, y=y_test
Expand All @@ -389,7 +389,7 @@ def test_estimator_report_display_regression_external_data(
external data.
"""
estimator, X_test, y_test = regression_data
report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
report = EstimatorReport(estimator)
assert hasattr(report.metrics.plot, display)
display_first_call = getattr(report.metrics.plot, display)(
data_source="X_y", X=X_test, y=y_test
Expand Down Expand Up @@ -827,20 +827,23 @@ def test_estimator_report_get_X_y_and_data_source_hash_error():
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

estimator = LogisticRegression().fit(X_train, y_train)
report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
report = EstimatorReport(estimator)

err_msg = re.escape(
"Invalid data source: unknown. Possible values are: " "test, train, X_y."
)
with pytest.raises(ValueError, match=err_msg):
report.metrics.log_loss(data_source="unknown")

err_msg = re.escape(
"No training data (i.e. X_train and y_train) were provided "
"when creating the reporter. Please provide the training data."
)
with pytest.raises(ValueError, match=err_msg):
report.metrics.log_loss(data_source="train")
for data_source in ("train", "test"):
err_msg = re.escape(
f"No {data_source} data (i.e. X_{data_source} and y_{data_source}) were "
f"provided when creating the reporter. Please provide the {data_source} "
"data either when creating the reporter or by setting data_source to "
"'X_y' and providing X and y."
)
with pytest.raises(ValueError, match=err_msg):
report.metrics.log_loss(data_source=data_source)

report = EstimatorReport(
estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
Expand All @@ -865,14 +868,18 @@ def test_estimator_report_get_X_y_and_data_source_hash_error():
rand_score, response_method="predict", data_source="X_y"
)

err_msg = re.escape(
"No training data (i.e. X_train) were provided when creating the reporter. "
"Please provide the training data."
)
with pytest.raises(ValueError, match=err_msg):
report.metrics.custom_metric(
rand_score, response_method="predict", data_source="train"
report = EstimatorReport(estimator)
for data_source in ("train", "test"):
err_msg = re.escape(
f"No {data_source} data (i.e. X_{data_source}) were provided when "
f"creating the reporter. Please provide the {data_source} data either "
f"when creating the reporter or by setting data_source to 'X_y' and "
f"providing X and y."
)
with pytest.raises(ValueError, match=err_msg):
report.metrics.custom_metric(
rand_score, response_method="predict", data_source=data_source
)


@pytest.mark.parametrize("data_source", ("train", "test", "X_y"))
Expand Down

0 comments on commit b8d4610

Please sign in to comment.