Skip to content

Commit

Permalink
fix(EstimatorReport): Deepcopy estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
auguste-probabl committed Jan 10, 2025
1 parent 4c819ec commit 5b9b55a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
14 changes: 14 additions & 0 deletions skore/src/skore/sklearn/_estimator/report.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import copy
import inspect
import time
import warnings
from itertools import product

import joblib
Expand Down Expand Up @@ -100,6 +102,18 @@ def __init__(
else: # fit is False
self._estimator = estimator

try:
self._estimator = copy.deepcopy(self._estimator)
except Exception as e:
warnings.warn(
"Deepcopy failed; using estimator as-is. "
"Be aware that modifying the estimator outside of "
f"{self.__class__.__name__} will modify the internal estimator. "
"Consider using a FrozenEstimator from scikit-learn to prevent this. "
f"Original error: {e}",
stacklevel=1,
)

# private storage to be able to invalidate the cache when the user alters
# those attributes
self._X_train = X_train
Expand Down
26 changes: 24 additions & 2 deletions skore/tests/unit/sklearn/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def test_estimator_report_from_fitted_estimator(binary_classification_data, fit)
estimator, X, y = binary_classification_data
report = EstimatorReport(estimator, fit=fit, X_test=X, y_test=y)

assert report.estimator is estimator # we should not clone the estimator
check_is_fitted(report.estimator)
assert isinstance(report.estimator, RandomForestClassifier)
assert report.X_train is None
assert report.y_train is None
assert report.X_test is X
Expand All @@ -209,7 +210,8 @@ def test_estimator_report_from_fitted_pipeline(binary_classification_data_pipeli
estimator, X, y = binary_classification_data_pipeline
report = EstimatorReport(estimator, X_test=X, y_test=y)

assert report.estimator is estimator # we should not clone the estimator
check_is_fitted(report.estimator)
assert isinstance(report.estimator, Pipeline)
assert report.estimator_name == estimator[-1].__class__.__name__
assert report.X_train is None
assert report.y_train is None
Expand Down Expand Up @@ -950,3 +952,23 @@ def test_estimator_has_side_effects():
predictions_after = report.estimator.predict_proba(X_test)

np.testing.assert_array_equal(predictions_before, predictions_after)


def test_estimator_has_no_deep_copy():
X, y = make_classification(n_classes=2, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

estimator = LogisticRegression()
# Make it so deepcopy does not work
estimator.__reduce_ex__ = None
estimator.__reduce__ = None

with pytest.warns(UserWarning, match="Deepcopy failed"):
EstimatorReport(
estimator,
fit=False,
X_train=X_train,
X_test=X_test,
y_train=y_train,
y_test=y_test,
)

0 comments on commit 5b9b55a

Please sign in to comment.