From 1a4151a8ee06b1077f13fd915b24e17261b5c4e4 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 10 Jan 2025 09:18:48 +0100 Subject: [PATCH] feat: Design of `EstimatorReport` (#997) closes https://github.com/probabl-ai/skore/issues/834 Investigate an API for a `EstimatorReport`. #### TODO - [x] Metrics - [x] handle string metrics has specified in the accessor - [x] handle callable metrics - [x] handle scikit-learn scorers - [x] use efficiently the cache as much as possible - [x] add testing for all of those features - [x] allow to pass new validation set to functions instead of using the internal validation set - [x] add a proper help and rich `__repr__` - [x] Plots - [x] add the roc curve display - [x] add the precision recall curve display - [x] add prediction error display for regressor - [x] make proper testing for those displays - [x] add a proper `__repr__` for those displays - [x] Documentation - [x] (done for the checked part) add an example to showcase all the different features - [x] find a way to show the accessors documentation in the page of `EstimatorReport`. It could be a bit tricky because they are only defined once the instance created. - We need to have a look at the `series.rst` page from pandas to see how they document this sort of pattern. - [x] check the autocompletion: when typing `report.metrics.->tab` it should provide the autocompetion. **edit**: having a stub file is actually working. I prefer this than type hints directly in the file. - Open questions - [x] we use hashing to retrieve external set. - use the caching for the external validation set? To make it work we need to compute the hash of potentially big arrays. This might more costly than making the model predict. #### Notes This PR build upon: - https://github.com/probabl-ai/skore/pull/962 to reuse the `skore.console` - https://github.com/probabl-ai/skore/pull/998 to be able to detect clusterer in a consistent manner. --- .../model_evaluation/plot_estimator_report.py | 385 ++++++ skore/pyproject.toml | 7 +- skore/src/skore/__init__.py | 3 +- .../src/skore/externals/_pandas_accessors.py | 53 + skore/src/skore/sklearn/__init__.py | 2 + .../src/skore/sklearn/_estimator/__init__.py | 25 + .../src/skore/sklearn/_estimator/__init__.pyi | 3 + skore/src/skore/sklearn/_estimator/base.py | 174 +++ skore/src/skore/sklearn/_estimator/base.pyi | 33 + .../sklearn/_estimator/metrics_accessor.py | 1096 +++++++++++++++++ .../sklearn/_estimator/metrics_accessor.pyi | 168 +++ skore/src/skore/sklearn/_estimator/report.py | 413 +++++++ skore/src/skore/sklearn/_estimator/report.pyi | 71 ++ skore/src/skore/sklearn/_estimator/utils.py | 19 + skore/src/skore/sklearn/_plot/__init__.py | 9 + .../sklearn/_plot/precision_recall_curve.py | 511 ++++++++ .../skore/sklearn/_plot/prediction_error.py | 318 +++++ skore/src/skore/sklearn/_plot/roc_curve.py | 399 ++++++ skore/src/skore/sklearn/_plot/utils.py | 200 +++ skore/src/skore/utils/_accessor.py | 15 + skore/tests/conftest.py | 25 + skore/tests/unit/sklearn/plot/test_common.py | 80 ++ .../plot/test_precision_recall_curve.py | 238 ++++ .../sklearn/plot/test_prediction_error.py | 136 ++ .../tests/unit/sklearn/plot/test_roc_curve.py | 199 +++ skore/tests/unit/sklearn/plot/test_utils.py | 65 + skore/tests/unit/sklearn/test_estimator.py | 927 ++++++++++++++ skore/tests/unit/utils/test_accessors.py | 60 + sphinx/_templates/autosummary/accessor.rst | 5 + .../autosummary/accessor_attribute.rst | 5 + .../autosummary/accessor_callable.rst | 5 + .../autosummary/accessor_method.rst | 5 + sphinx/api.rst | 56 + sphinx/conf.py | 10 +- 34 files changed, 5713 insertions(+), 7 deletions(-) create mode 100644 examples/model_evaluation/plot_estimator_report.py create mode 100644 skore/src/skore/externals/_pandas_accessors.py create mode 100644 skore/src/skore/sklearn/_estimator/__init__.py create mode 100644 skore/src/skore/sklearn/_estimator/__init__.pyi create mode 100644 skore/src/skore/sklearn/_estimator/base.py create mode 100644 skore/src/skore/sklearn/_estimator/base.pyi create mode 100644 skore/src/skore/sklearn/_estimator/metrics_accessor.py create mode 100644 skore/src/skore/sklearn/_estimator/metrics_accessor.pyi create mode 100644 skore/src/skore/sklearn/_estimator/report.py create mode 100644 skore/src/skore/sklearn/_estimator/report.pyi create mode 100644 skore/src/skore/sklearn/_estimator/utils.py create mode 100644 skore/src/skore/sklearn/_plot/__init__.py create mode 100644 skore/src/skore/sklearn/_plot/precision_recall_curve.py create mode 100644 skore/src/skore/sklearn/_plot/prediction_error.py create mode 100644 skore/src/skore/sklearn/_plot/roc_curve.py create mode 100644 skore/src/skore/sklearn/_plot/utils.py create mode 100644 skore/src/skore/utils/_accessor.py create mode 100644 skore/tests/unit/sklearn/plot/test_common.py create mode 100644 skore/tests/unit/sklearn/plot/test_precision_recall_curve.py create mode 100644 skore/tests/unit/sklearn/plot/test_prediction_error.py create mode 100644 skore/tests/unit/sklearn/plot/test_roc_curve.py create mode 100644 skore/tests/unit/sklearn/plot/test_utils.py create mode 100644 skore/tests/unit/sklearn/test_estimator.py create mode 100644 skore/tests/unit/utils/test_accessors.py create mode 100644 sphinx/_templates/autosummary/accessor.rst create mode 100644 sphinx/_templates/autosummary/accessor_attribute.rst create mode 100644 sphinx/_templates/autosummary/accessor_callable.rst create mode 100644 sphinx/_templates/autosummary/accessor_method.rst diff --git a/examples/model_evaluation/plot_estimator_report.py b/examples/model_evaluation/plot_estimator_report.py new file mode 100644 index 000000000..fb8fd30f2 --- /dev/null +++ b/examples/model_evaluation/plot_estimator_report.py @@ -0,0 +1,385 @@ +""" +============================================ +Get insights from any scikit-learn estimator +============================================ + +This example shows how the :class:`skore.EstimatorReport` class can be used to +quickly get insights from any scikit-learn estimator. +""" + +# %% +# +# TODO: we need to describe the aim of this classification problem. +from skrub.datasets import fetch_open_payments + +dataset = fetch_open_payments() +df = dataset.X +y = dataset.y + +# %% +from skrub import TableReport + +TableReport(df) + +# %% +TableReport(y.to_frame()) + +# %% +# Looking at the distributions of the target, we observe that this classification +# task is quite imbalanced. It means that we have to be careful when selecting a set +# of statistical metrics to evaluate the classification performance of our predictive +# model. In addition, we see that the class labels are not specified by an integer +# 0 or 1 but instead by a string "allowed" or "disallowed". +# +# For our application, the label of interest is "allowed". +pos_label, neg_label = "allowed", "disallowed" + +# %% +# Before training a predictive model, we need to split our dataset into a training +# and a validation set. +from skore import train_test_split + +X_train, X_test, y_train, y_test = train_test_split(df, y, random_state=42) + +# %% +# TODO: we have a perfect case to show useful feature of the `train_test_split` +# function from `skore`. +# +# Now, we need to define a predictive model. Hopefully, `skrub` provides a convenient +# function (:func:`skrub.tabular_learner`) when it comes to getting strong baseline +# predictive models with a single line of code. As its feature engineering is generic, +# it does not provide some handcrafted and tailored feature engineering but still +# provides a good starting point. +# +# So let's create a classifier for our task and fit it on the training set. +from skrub import tabular_learner + +estimator = tabular_learner("classifier").fit(X_train, y_train) +estimator + +# %% +# +# Introducing the :class:`skore.EstimatorReport` class +# ---------------------------------------------------- +# +# Now, we would be interested in getting some insights from our predictive model. +# One way is to use the :class:`skore.EstimatorReport` class. This constructor will +# detect that our estimator is already fitted and will not fit it again. +from skore import EstimatorReport + +reporter = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test +) +reporter + +# %% +# +# Once the reporter is created, we get some information regarding the available tools +# allowing us to get some insights from our specific model on the specific task. +# +# You can get a similar information if you call the :meth:`~skore.EstimatorReport.help` +# method. +reporter.help() + +# %% +# +# Be aware that you can access the help for each individual sub-accessor. For instance: +reporter.metrics.help() + +# %% +reporter.metrics.plot.help() + +# %% +# +# Metrics computation with aggressive caching +# ------------------------------------------- +# +# At this point, we might be interested to have a first look at the statistical +# performance of our model on the validation set that we provided. We can access it +# by calling any of the metrics displayed above. Since we are greedy, we want to get +# several metrics at once and we will use the +# :meth:`~skore.EstimatorReport.metrics.report_metrics` method. +import time + +start = time.time() +metric_report = reporter.metrics.report_metrics(pos_label=pos_label) +end = time.time() +metric_report + +# %% +print(f"Time taken to compute the metrics: {end - start:.2f} seconds") + +# %% +# +# An interesting feature provided by the :class:`skore.EstimatorReport` is the +# the caching mechanism. Indeed, when we have a large enough dataset, computing the +# predictions for a model is not cheap anymore. For instance, on our smallish dataset, +# it took a couple of seconds to compute the metrics. The reporter will cache the +# predictions and if you are interested in computing a metric again or an alternative +# metric that requires the same predictions, it will be faster. Let's check by +# requesting the same metrics report again. + +start = time.time() +metric_report = reporter.metrics.report_metrics(pos_label=pos_label) +end = time.time() +metric_report + +# %% +print(f"Time taken to compute the metrics: {end - start:.2f} seconds") + +# %% +# +# Since we obtain a pandas dataframe, we can also use the plotting interface of +# pandas. +import matplotlib.pyplot as plt + +ax = metric_report.T.plot.barh() +ax.set_title("Metrics report") +plt.tight_layout() + +# %% +# +# Whenever computing a metric, we check if the predictions are available in the cache +# and reload them if available. So for instance, let's compute the log loss. + +start = time.time() +log_loss = reporter.metrics.log_loss() +end = time.time() +log_loss + +# %% +print(f"Time taken to compute the log loss: {end - start:.2f} seconds") + +# %% +# +# We can show that without initial cache, it would have taken more time to compute +# the log loss. +reporter.clean_cache() + +start = time.time() +log_loss = reporter.metrics.log_loss() +end = time.time() +log_loss + +# %% +print(f"Time taken to compute the log loss: {end - start:.2f} seconds") + +# %% +# +# By default, the metrics are computed on the test set. However, if a training set +# is provided, we can also compute the metrics by specifying the `data_source` +# parameter. +reporter.metrics.log_loss(data_source="train") + +# %% +# +# In the case where we are interested in computing the metrics on a completely new set +# of data, we can use the `data_source="X_y"` parameter. In addition, we need to provide +# a `X` and `y` parameters. + +start = time.time() +metric_report = reporter.metrics.report_metrics( + data_source="X_y", X=X_test, y=y_test, pos_label=pos_label +) +end = time.time() +metric_report + +# %% +print(f"Time taken to compute the metrics: {end - start:.2f} seconds") + +# %% +# +# As in the other case, we rely on the cache to avoid recomputing the predictions. +# Internally, we compute a hash of the input data to be sure that we can hit the cache +# in a consistent way. + +# %% +start = time.time() +metric_report = reporter.metrics.report_metrics( + data_source="X_y", X=X_test, y=y_test, pos_label=pos_label +) +end = time.time() +metric_report + +# %% +print(f"Time taken to compute the metrics: {end - start:.2f} seconds") + +# %% +# +# .. warning:: +# In this last example, we rely on computing the hash of the input data. Therefore, +# there is a trade-off: the computation of the hash is not free and it might be +# faster to compute the predictions instead. +# +# Be aware that you can also benefit from the caching mechanism with your own custom +# metrics. We only expect that you define your own metric function to take `y_true` +# and `y_pred` as the first two positional arguments. It can take any other arguments. +# Let's see an example. + + +def operational_decision_cost(y_true, y_pred, amount): + mask_true_positive = (y_true == pos_label) & (y_pred == pos_label) + mask_true_negative = (y_true == neg_label) & (y_pred == neg_label) + mask_false_positive = (y_true == neg_label) & (y_pred == pos_label) + mask_false_negative = (y_true == pos_label) & (y_pred == neg_label) + # FIXME: we need to make sense of the cost sensitive part with the right naming + fraudulent_refuse = mask_true_positive.sum() * 50 + fraudulent_accept = -amount[mask_false_negative].sum() + legitimate_refuse = mask_false_positive.sum() * -5 + legitimate_accept = (amount[mask_true_negative] * 0.02).sum() + return fraudulent_refuse + fraudulent_accept + legitimate_refuse + legitimate_accept + + +# %% +# +# In our use case, we have a operational decision to make that translate the +# classification outcome into a cost. It translate the confusion matrix into a cost +# matrix based on some amount linked to each sample in the dataset that are provided to +# us. Here, we randomly generate some amount as an illustration. +import numpy as np + +rng = np.random.default_rng(42) +amount = rng.integers(low=100, high=1000, size=len(y_test)) + +# %% +# +# Let's make sure that a function called the `predict` method and cached the result. +# We compute the accuracy metric to make sure that the `predict` method is called. +reporter.metrics.accuracy() + +# %% +# +# We can now compute the cost of our operational decision. +start = time.time() +cost = reporter.metrics.custom_metric( + metric_function=operational_decision_cost, + metric_name="Operational Decision Cost", + response_method="predict", + amount=amount, +) +end = time.time() +cost + +# %% +print(f"Time taken to compute the cost: {end - start:.2f} seconds") + +# %% +# +# Let's now clean the cache and see if it is faster. +reporter.clean_cache() + +# %% +start = time.time() +cost = reporter.metrics.custom_metric( + metric_function=operational_decision_cost, + metric_name="Operational Decision Cost", + response_method="predict", + amount=amount, +) +end = time.time() +cost + +# %% +print(f"Time taken to compute the cost: {end - start:.2f} seconds") + +# %% +# +# We observe that caching is working as expected. It is really handy because it means +# that you can compute some additional metrics without having to recompute the +# the predictions. +reporter.metrics.report_metrics( + scoring=["precision", "recall", operational_decision_cost], + pos_label=pos_label, + scoring_kwargs={ + "amount": amount, + "response_method": "predict", + "metric_name": "Operational Decision Cost", + }, +) + +# %% +# +# It could happen that you are interested in providing several custom metrics which +# does not necessarily share the same parameters. In this more complex case, we will +# require you to provide a scorer using the :func:`sklearn.metrics.make_scorer` +# function. +from sklearn.metrics import make_scorer, f1_score + +f1_scorer = make_scorer( + f1_score, + response_method="predict", + metric_name="F1 Score", + pos_label=pos_label, +) +operational_decision_cost_scorer = make_scorer( + operational_decision_cost, + response_method="predict", + metric_name="Operational Decision Cost", + amount=amount, +) +reporter.metrics.report_metrics(scoring=[f1_scorer, operational_decision_cost_scorer]) + +# %% +# +# Effortless one-liner plotting +# ----------------------------- +# +# The :class:`skore.EstimatorReport` class also provides a plotting interface that +# allows to plot *defacto* the most common plots. As for the the metrics, we only +# provide the meaningful set of plots for the provided estimator. +reporter.metrics.plot.help() + +# %% +# +# Let's start by plotting the ROC curve for our binary classification task. +display = reporter.metrics.plot.roc(pos_label=pos_label) +plt.tight_layout() + +# %% +# +# The plot functionality is built upon the scikit-learn display objects. We return +# those display (slightly modified to improve the UI) in case you want to tweak some +# of the plot properties. You can have quick look at the available attributes and +# methods by calling the `help` method or simply by printing the display. +display + +# %% +display.help() + +# %% +display.plot() +display.ax_.set_title("Example of a ROC curve") +display.figure_ +plt.tight_layout() + +# %% +# +# Similarly to the metrics, we aggressively use the caching to avoid recomputing the +# predictions of the model. We also cache the plot display object by detection if the +# input parameters are the same as the previous call. Let's demonstrate the kind of +# performance gain we can get. +start = time.time() +# we already trigger the computation of the predictions in a previous call +reporter.metrics.plot.roc(pos_label=pos_label) +plt.tight_layout() +end = time.time() + +# %% +print(f"Time taken to compute the ROC curve: {end - start:.2f} seconds") + +# %% +# +# Now, let's clean the cache and check if we get a slowdown. +reporter.clean_cache() + +# %% +start = time.time() +reporter.metrics.plot.roc(pos_label=pos_label) +plt.tight_layout() +end = time.time() + +# %% +print(f"Time taken to compute the ROC curve: {end - start:.2f} seconds") + +# %% +# As expected, since we need to recompute the predictions, it takes more time. diff --git a/skore/pyproject.toml b/skore/pyproject.toml index 085d03c65..37d8124af 100644 --- a/skore/pyproject.toml +++ b/skore/pyproject.toml @@ -8,6 +8,8 @@ dependencies = [ "diskcache", "fastapi", "numpy", + "pandas", + "matplotlib", "plotly>=5,<6", "pyarrow", "rich", @@ -66,8 +68,6 @@ artifacts = ["src/skore/ui/static/"] test = [ "altair>=5,<6", "httpx", - "matplotlib", - "pandas", "pillow", "plotly", "polars", @@ -85,13 +85,12 @@ test = [ sphinx = [ "IPython", "altair", - "matplotlib", "numpydoc", - "pandas", "polars", "kaleido", "pydata-sphinx-theme", "sphinx", + "sphinx_autosummary_accessors", "sphinx-design", "sphinx-gallery", "sphinx-copybutton", diff --git a/skore/src/skore/__init__.py b/skore/src/skore/__init__.py index 1f33543cf..135bd0947 100644 --- a/skore/src/skore/__init__.py +++ b/skore/src/skore/__init__.py @@ -6,11 +6,12 @@ from rich.theme import Theme from skore.project import Project, open -from skore.sklearn import CrossValidationReporter, train_test_split +from skore.sklearn import CrossValidationReporter, EstimatorReport, train_test_split from skore.utils._show_versions import show_versions __all__ = [ "CrossValidationReporter", + "EstimatorReport", "open", "Project", "show_versions", diff --git a/skore/src/skore/externals/_pandas_accessors.py b/skore/src/skore/externals/_pandas_accessors.py new file mode 100644 index 000000000..7dabdc6bc --- /dev/null +++ b/skore/src/skore/externals/_pandas_accessors.py @@ -0,0 +1,53 @@ +"""Pandas-like accessors. + +This code is copied from: +https://github.com/pandas-dev/pandas/blob/main/pandas/core/accessor.py + +It is used to register accessors for the skore classes. +""" + +from typing import final + + +class DirNamesMixin: + _accessors: set[str] = set() + _hidden_attrs: frozenset[str] = frozenset() + + @final + def _dir_deletions(self) -> set[str]: + return self._accessors | self._hidden_attrs + + def _dir_additions(self) -> set[str]: + return {accessor for accessor in self._accessors if hasattr(self, accessor)} + + def __dir__(self) -> list[str]: + rv = set(super().__dir__()) + rv = (rv - self._dir_deletions()) | self._dir_additions() + return sorted(rv) + + +class Accessor: + def __init__(self, name: str, accessor) -> None: + self._name = name + self._accessor = accessor + + def __get__(self, obj, cls): + if obj is None: + # we're accessing the attribute of the class, i.e., Dataset.geo + return self._accessor + return self._accessor(obj) + + +def _register_accessor(name, cls): + def decorator(accessor): + if hasattr(cls, name): + raise ValueError( + f"registration of accessor {accessor!r} under name " + f"{name!r} for type {cls!r} is overriding a preexisting " + f"attribute with the same name." + ) + setattr(cls, name, Accessor(name, accessor)) + cls._accessors.add(name) + return accessor + + return decorator diff --git a/skore/src/skore/sklearn/__init__.py b/skore/src/skore/sklearn/__init__.py index 9331d60a8..eb3d2188f 100644 --- a/skore/src/skore/sklearn/__init__.py +++ b/skore/src/skore/sklearn/__init__.py @@ -1,9 +1,11 @@ """Enhance `sklearn` functions.""" +from skore.sklearn._estimator import EstimatorReport from skore.sklearn.cross_validation import CrossValidationReporter from skore.sklearn.train_test_split.train_test_split import train_test_split __all__ = [ "train_test_split", "CrossValidationReporter", + "EstimatorReport", ] diff --git a/skore/src/skore/sklearn/_estimator/__init__.py b/skore/src/skore/sklearn/_estimator/__init__.py new file mode 100644 index 000000000..ba7a3058a --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/__init__.py @@ -0,0 +1,25 @@ +from skore.externals._pandas_accessors import _register_accessor +from skore.sklearn._estimator.metrics_accessor import ( + _MetricsAccessor, + _PlotMetricsAccessor, +) +from skore.sklearn._estimator.report import EstimatorReport + + +def register_estimator_report_accessor(name: str): + """Register an accessor for the EstimatorReport class.""" + return _register_accessor(name, EstimatorReport) + + +def register_metrics_accessor(name: str): + """Register an accessor for the EstimatorReport class.""" + return _register_accessor(name, _MetricsAccessor) + + +# add the plot accessor to the metrics accessor +register_metrics_accessor("plot")(_PlotMetricsAccessor) + +# add the metrics accessor to the estimator report +register_estimator_report_accessor("metrics")(_MetricsAccessor) + +__all__ = ["EstimatorReport"] diff --git a/skore/src/skore/sklearn/_estimator/__init__.pyi b/skore/src/skore/sklearn/_estimator/__init__.pyi new file mode 100644 index 000000000..f496ff1f3 --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/__init__.pyi @@ -0,0 +1,3 @@ +from skore.sklearn._estimator.report import EstimatorReport + +__all__ = ["EstimatorReport"] diff --git a/skore/src/skore/sklearn/_estimator/base.py b/skore/src/skore/sklearn/_estimator/base.py new file mode 100644 index 000000000..fffb67c11 --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/base.py @@ -0,0 +1,174 @@ +import inspect +from io import StringIO + +import joblib +from rich.console import Console, Group +from rich.panel import Panel +from rich.tree import Tree + +from skore.externals._sklearn_compat import is_clusterer + + +class _HelpMixin: + """Mixin class providing help for the `help` method and the `__repr__` method.""" + + def _get_methods_for_help(self): + """Get the methods to display in help.""" + methods = inspect.getmembers(self, predicate=inspect.ismethod) + filtered_methods = [] + for name, method in methods: + is_private_method = name.startswith("_") + # we cannot use `isinstance(method, classmethod)` because it is already + # transformed by the decorator `@classmethod`. + is_class_method = inspect.ismethod(method) and method.__self__ is type(self) + is_help_method = name == "help" + if not (is_private_method or is_class_method or is_help_method): + filtered_methods.append((name, method)) + return filtered_methods + + def _sort_methods_for_help(self, methods): + """Sort methods for help display.""" + return sorted(methods) + + def _format_method_name(self, name): + """Format method name for display.""" + return f"{name}(...)" + + def _get_method_description(self, method): + """Get the description for a method.""" + return ( + method.__doc__.split("\n")[0] + if method.__doc__ + else "No description available" + ) + + def _get_help_legend(self): + """Get the help legend.""" + return None + + def _create_help_panel(self): + """Create the help panel.""" + if self._get_help_legend(): + content = Group( + self._create_help_tree(), + f"\n\nLegend:\n{self._get_help_legend()}", + ) + else: + content = self._create_help_tree() + + return Panel( + content, + title=self._get_help_panel_title(), + expand=False, + border_style="orange1", + ) + + def help(self): + """Display available methods using rich.""" + from skore import console # avoid circular import + + console.print(self._create_help_panel()) + + def __repr__(self): + """Return a string representation using rich.""" + console = Console(file=StringIO(), force_terminal=False) + console.print(self._create_help_panel()) + return console.file.getvalue() + + +class _BaseAccessor(_HelpMixin): + """Base class for all accessors.""" + + def __init__(self, parent, icon): + self._parent = parent + self._icon = icon + + def _get_help_panel_title(self): + name = self.__class__.__name__.replace("_", "").replace("Accessor", "").lower() + return f"{self._icon} Available {name} methods" + + def _create_help_tree(self): + """Create a rich Tree with the available methods.""" + tree = Tree(self._get_help_tree_title()) + + methods = self._get_methods_for_help() + methods = self._sort_methods_for_help(methods) + + for name, method in methods: + displayed_name = self._format_method_name(name) + description = self._get_method_description(method) + tree.add(f".{displayed_name}".ljust(26) + f" - {description}") + + return tree + + def _get_X_y_and_data_source_hash(self, *, data_source, X=None, y=None): + """Get the requested dataset and mention if we should hash before caching. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features) or None, default=None + The input data. + + y : array-like of shape (n_samples,) or None, default=None + The target data. + + Returns + ------- + X : array-like of shape (n_samples, n_features) + The requested dataset. + + y : array-like of shape (n_samples,) + The requested dataset. + + data_source_hash : int or None + The hash of the data source. None when we are able to track the data, and + thus relying on X_train, y_train, X_test, y_test. + """ + is_cluster = is_clusterer(self._parent.estimator) + if data_source == "test": + if not (X is None or y is None): + raise ValueError("X and y must be None when data_source is test.") + if self._parent._X_test is None or ( + not is_cluster and self._parent._y_test is None + ): + missing_data = "X_test" if is_cluster else "X_test and y_test" + raise ValueError( + f"No {data_source} data (i.e. {missing_data}) were provided " + f"when creating the reporter. Please provide the {data_source} " + "data either when creating the reporter or by setting data_source " + "to 'X_y' and providing X and y." + ) + return self._parent._X_test, self._parent._y_test, None + elif data_source == "train": + if not (X is None or y is None): + raise ValueError("X and y must be None when data_source is train.") + if self._parent._X_train is None or ( + not is_cluster and self._parent._y_train is None + ): + missing_data = "X_train" if is_cluster else "X_train and y_train" + raise ValueError( + f"No {data_source} data (i.e. {missing_data}) were provided " + f"when creating the reporter. Please provide the {data_source} " + "data either when creating the reporter or by setting data_source " + "to 'X_y' and providing X and y." + ) + return self._parent._X_train, self._parent._y_train, None + elif data_source == "X_y": + if X is None or (not is_cluster and y is None): + missing_data = "X" if is_cluster else "X and y" + raise ValueError( + f"{missing_data} must be provided when data_source is X_y." + ) + return X, y, joblib.hash((X, y)) + else: + raise ValueError( + f"Invalid data source: {data_source}. Possible values are: " + "test, train, X_y." + ) diff --git a/skore/src/skore/sklearn/_estimator/base.pyi b/skore/src/skore/sklearn/_estimator/base.pyi new file mode 100644 index 000000000..eca62e85e --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/base.pyi @@ -0,0 +1,33 @@ +from typing import Any, Literal, Optional + +import numpy as np +from rich.panel import Panel +from rich.tree import Tree + +class _HelpMixin: + def _get_methods_for_help(self) -> list[tuple[str, Any]]: ... + def _sort_methods_for_help( + self, methods: list[tuple[str, Any]] + ) -> list[tuple[str, Any]]: ... + def _format_method_name(self, name: str) -> str: ... + def _get_method_description(self, method: Any) -> str: ... + def _create_help_panel(self) -> Panel: ... + def _get_help_panel_title(self) -> str: ... + def _create_help_tree(self) -> Tree: ... + def help(self) -> None: ... + def __repr__(self) -> str: ... + +class _BaseAccessor(_HelpMixin): + _parent: Any + _icon: str + + def __init__(self, parent: Any, icon: str) -> None: ... + def _get_help_panel_title(self) -> str: ... + def _create_help_tree(self) -> Tree: ... + def _get_X_y_and_data_source_hash( + self, + *, + data_source: Literal["test", "train", "X_y"], + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + ) -> tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[str]]: ... diff --git a/skore/src/skore/sklearn/_estimator/metrics_accessor.py b/skore/src/skore/sklearn/_estimator/metrics_accessor.py new file mode 100644 index 000000000..0f41595ac --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/metrics_accessor.py @@ -0,0 +1,1096 @@ +import inspect +from functools import partial + +import joblib +import numpy as np +import pandas as pd +from sklearn import metrics +from sklearn.metrics._scorer import _BaseScorer +from sklearn.utils.metaestimators import available_if + +from skore.externals._pandas_accessors import DirNamesMixin +from skore.sklearn._estimator.base import _BaseAccessor +from skore.sklearn._plot import ( + PrecisionRecallCurveDisplay, + PredictionErrorDisplay, + RocCurveDisplay, +) +from skore.utils._accessor import _check_supported_ml_task + +############################################################################### +# Metrics accessor +############################################################################### + + +class _MetricsAccessor(_BaseAccessor, DirNamesMixin): + """Accessor for metrics-related operations. + + You can access this accessor using the `metrics` attribute. + """ + + _SCORE_OR_LOSS_ICONS = { + "accuracy": "(↗︎)", + "precision": "(↗︎)", + "recall": "(↗︎)", + "brier_score": "(↘︎)", + "roc_auc": "(↗︎)", + "log_loss": "(↘︎)", + "r2": "(↗︎)", + "rmse": "(↘︎)", + "report_metrics": "", + "custom_metric": "", + } + + def __init__(self, parent): + super().__init__(parent, icon="📏") + + # TODO: should build on the `add_scorers` function + def report_metrics( + self, + *, + data_source="test", + X=None, + y=None, + scoring=None, + pos_label=None, + scoring_kwargs=None, + ): + """Report a set of metrics for our estimator. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + scoring : list of str, callable, or scorer, default=None + The metrics to report. You can get the possible list of string by calling + `reporter.metrics.help()`. When passing a callable, it should take as + arguments `y_true`, `y_pred` as the two first arguments. Additional + arguments can be passed as keyword arguments and will be forwarded with + `scoring_kwargs`. If the callable API is too restrictive (e.g. need to pass + same parameter name with different values), you can use scikit-learn scorers + as provided by :func:`sklearn.metrics.make_scorer`. + + pos_label : int, default=None + The positive class. + + scoring_kwargs : dict, default=None + The keyword arguments to pass to the scoring functions. + + Returns + ------- + pd.DataFrame + The statistics for the metrics. + """ + if scoring is None: + # Equivalent to _get_scorers_to_add + if self._parent._ml_task == "binary-classification": + scoring = ["precision", "recall", "roc_auc"] + if hasattr(self._parent._estimator, "predict_proba"): + scoring.append("brier_score") + elif self._parent._ml_task == "multiclass-classification": + scoring = ["precision", "recall"] + if hasattr(self._parent._estimator, "predict_proba"): + scoring += ["roc_auc", "log_loss"] + else: + scoring = ["r2", "rmse"] + + scores = [] + + for metric in scoring: + # NOTE: we have to check specifically for `_BaseScorer` first because this + # is also a callable but it has a special private API that we can leverage + if isinstance(metric, _BaseScorer): + # scorers have the advantage to have scoped defined kwargs + metric_fn = partial( + self.custom_metric, + metric_function=metric._score_func, + response_method=metric._response_method, + ) + # forward the additional parameters specific to the scorer + metrics_kwargs = {**metric._kwargs} + elif isinstance(metric, str) or callable(metric): + if isinstance(metric, str): + metric_fn = getattr(self, metric) + metrics_kwargs = {} + else: + metric_fn = partial(self.custom_metric, metric_function=metric) + if scoring_kwargs is None: + metrics_kwargs = {} + else: + # check if we should pass any parameters specific to the metric + # callable + metric_callable_params = inspect.signature(metric).parameters + metrics_kwargs = { + param: scoring_kwargs[param] + for param in metric_callable_params + if param in scoring_kwargs + } + metrics_params = inspect.signature(metric_fn).parameters + if scoring_kwargs is not None: + for param in metrics_params: + if param in scoring_kwargs: + metrics_kwargs[param] = scoring_kwargs[param] + if "pos_label" in metrics_params: + metrics_kwargs["pos_label"] = pos_label + else: + raise ValueError( + f"Invalid type of metric: {type(metric)} for {metric!r}" + ) + + scores.append( + metric_fn(data_source=data_source, X=X, y=y, **metrics_kwargs) + ) + + has_multilevel = any( + isinstance(score, pd.DataFrame) and isinstance(score.columns, pd.MultiIndex) + for score in scores + ) + + if has_multilevel: + # Convert single-level dataframes to multi-level + for i, score in enumerate(scores): + if hasattr(score, "columns") and not isinstance( + score.columns, pd.MultiIndex + ): + name_index = ( + ["Metric", "Output"] + if self._parent._ml_task == "regression" + else ["Metric", "Class label"] + ) + scores[i].columns = pd.MultiIndex.from_tuples( + [(col, "") for col in score.columns], + names=name_index, + ) + + return pd.concat(scores, axis=1) + + def _compute_metric_scores( + self, + metric_fn, + X, + y_true, + *, + data_source="test", + response_method, + pos_label=None, + metric_name=None, + **metric_kwargs, + ): + X, y_true, data_source_hash = self._get_X_y_and_data_source_hash( + data_source=data_source, X=X, y=y_true + ) + + y_pred = self._parent._get_cached_response_values( + estimator_hash=self._parent._hash, + estimator=self._parent.estimator, + X=X, + response_method=response_method, + pos_label=pos_label, + data_source=data_source, + data_source_hash=data_source_hash, + ) + cache_key = (self._parent._hash, metric_fn.__name__, data_source) + if data_source_hash: + cache_key += (data_source_hash,) + + metric_params = inspect.signature(metric_fn).parameters + if "pos_label" in metric_params: + cache_key += (pos_label,) + if metric_kwargs != {}: + # we need to enforce the order of the parameter for a specific metric + # to make sure that we hit the cache in a consistent way + ordered_metric_kwargs = sorted(metric_kwargs.keys()) + cache_key += tuple( + ( + joblib.hash(metric_kwargs[key]) + if isinstance(metric_kwargs[key], np.ndarray) + else metric_kwargs[key] + ) + for key in ordered_metric_kwargs + ) + + if cache_key in self._parent._cache: + score = self._parent._cache[cache_key] + else: + metric_params = inspect.signature(metric_fn).parameters + kwargs = {**metric_kwargs} + if "pos_label" in metric_params: + kwargs.update(pos_label=pos_label) + + score = metric_fn(y_true, y_pred, **kwargs) + self._parent._cache[cache_key] = score + + score = np.array([score]) if not isinstance(score, np.ndarray) else score + metric_name = metric_name or metric_fn.__name__ + + if self._parent._ml_task in [ + "binary-classification", + "multiclass-classification", + ]: + if len(score) == 1: + columns = pd.Index([metric_name], name="Metric") + else: + classes = self._parent._estimator.classes_ + columns = pd.MultiIndex.from_arrays( + [[metric_name] * len(classes), classes], + names=["Metric", "Class label"], + ) + score = score.reshape(1, -1) + elif self._parent._ml_task == "regression": + if len(score) == 1: + columns = pd.Index([metric_name], name="Metric") + else: + columns = pd.MultiIndex.from_arrays( + [ + [metric_name] * len(score), + [f"#{i}" for i in range(len(score))], + ], + names=["Metric", "Output"], + ) + score = score.reshape(1, -1) + else: + # FIXME: clusterer would fall here. + columns = None + return pd.DataFrame(score, columns=columns, index=[self._parent.estimator_name]) + + @available_if( + _check_supported_ml_task( + supported_ml_tasks=["binary-classification", "multiclass-classification"] + ) + ) + def accuracy(self, *, data_source="test", X=None, y=None): + """Compute the accuracy score. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + Returns + ------- + pd.DataFrame + The accuracy score. + """ + return self._compute_metric_scores( + metrics.accuracy_score, + X=X, + y_true=y, + data_source=data_source, + response_method="predict", + metric_name=f"Accuracy {self._SCORE_OR_LOSS_ICONS['accuracy']}", + ) + + @available_if( + _check_supported_ml_task( + supported_ml_tasks=["binary-classification", "multiclass-classification"] + ) + ) + def precision( + self, *, data_source="test", X=None, y=None, average=None, pos_label=None + ): + """Compute the precision score. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + average : {"binary","macro", "micro", "weighted", "samples"} or None, \ + default=None + Used with multiclass problems. + If `None`, the metrics for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + + - "binary": Only report results for the class specified by `pos_label`. + This is applicable only if targets (`y_{true,pred}`) are binary. + - "micro": Calculate metrics globally by counting the total true positives, + false negatives and false positives. + - "macro": Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + - "weighted": Calculate metrics for each label, and find their average + weighted by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an F-score + that is not between precision and recall. + - "samples": Calculate metrics for each instance, and find their average + (only meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + + .. note:: + If `pos_label` is specified and `average` is None, then we report + only the statistics of the positive class (i.e. equivalent to + `average="binary"`). + + pos_label : int, default=None + The positive class. + + Returns + ------- + pd.DataFrame + The precision score. + """ + if self._parent._ml_task == "binary-classification" and pos_label is not None: + # if `pos_label` is specified by our user, then we can safely report only + # the statistics of the positive class + average = "binary" + + return self._compute_metric_scores( + metrics.precision_score, + X=X, + y_true=y, + data_source=data_source, + response_method="predict", + pos_label=pos_label, + metric_name=f"Precision {self._SCORE_OR_LOSS_ICONS['precision']}", + average=average, + ) + + @available_if( + _check_supported_ml_task( + supported_ml_tasks=["binary-classification", "multiclass-classification"] + ) + ) + def recall( + self, *, data_source="test", X=None, y=None, average=None, pos_label=None + ): + """Compute the recall score. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + average : {"binary","macro", "micro", "weighted", "samples"} or None, \ + default=None + Used with multiclass problems. + If `None`, the metrics for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + + - "binary": Only report results for the class specified by `pos_label`. + This is applicable only if targets (`y_{true,pred}`) are binary. + - "micro": Calculate metrics globally by counting the total true positives, + false negatives and false positives. + - "macro": Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + - "weighted": Calculate metrics for each label, and find their average + weighted by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an F-score + that is not between precision and recall. Weighted recall is equal to + accuracy. + - "samples": Calculate metrics for each instance, and find their average + (only meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + + .. note:: + If `pos_label` is specified and `average` is None, then we report + only the statistics of the positive class (i.e. equivalent to + `average="binary"`). + + pos_label : int, default=None + The positive class. + + Returns + ------- + pd.DataFrame + The recall score. + """ + if self._parent._ml_task == "binary-classification" and pos_label is not None: + # if `pos_label` is specified by our user, then we can safely report only + # the statistics of the positive class + average = "binary" + + return self._compute_metric_scores( + metrics.recall_score, + X=X, + y_true=y, + data_source=data_source, + response_method="predict", + pos_label=pos_label, + metric_name=f"Recall {self._SCORE_OR_LOSS_ICONS['recall']}", + average=average, + ) + + @available_if( + _check_supported_ml_task(supported_ml_tasks=["binary-classification"]) + ) + def brier_score(self, *, data_source="test", X=None, y=None): + """Compute the Brier score. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + Returns + ------- + pd.DataFrame + The Brier score. + """ + # The Brier score in scikit-learn request `pos_label` to ensure that the + # integral encoding of `y_true` corresponds to the probabilities of the + # `pos_label`. Since we get the predictions with `get_response_method`, we + # can pass any `pos_label`, they will lead to the same result. + return self._compute_metric_scores( + metrics.brier_score_loss, + X=X, + y_true=y, + data_source=data_source, + response_method="predict_proba", + metric_name=f"Brier score {self._SCORE_OR_LOSS_ICONS['brier_score']}", + pos_label=self._parent._estimator.classes_[-1], + ) + + @available_if( + _check_supported_ml_task( + supported_ml_tasks=["binary-classification", "multiclass-classification"] + ) + ) + def roc_auc( + self, *, data_source="test", X=None, y=None, average=None, multi_class="ovr" + ): + """Compute the ROC AUC score. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + average : {"auto", "macro", "micro", "weighted", "samples"}, \ + default=None + Average to compute the ROC AUC score in a multiclass setting. By default, + no average is computed. Otherwise, this determines the type of averaging + performed on the data. + + - "micro": Calculate metrics globally by considering each element of + the label indicator matrix as a label. + - "macro": Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + - "weighted": Calculate metrics for each label, and find their average, + weighted by support (the number of true instances for each label). + - "samples": Calculate metrics for each instance, and find their + average. + + .. note:: + Multiclass ROC AUC currently only handles the "macro" and + "weighted" averages. For multiclass targets, `average=None` is only + implemented for `multi_class="ovr"` and `average="micro"` is only + implemented for `multi_class="ovr"`. + + multi_class : {"raise", "ovr", "ovo"}, default="ovr" + The multi-class strategy to use. + + - "raise": Raise an error if the data is multiclass. + - "ovr": Stands for One-vs-rest. Computes the AUC of each class against the + rest. This treats the multiclass case in the same way as the multilabel + case. Sensitive to class imbalance even when `average == "macro"`, + because class imbalance affects the composition of each of the "rest" + groupings. + - "ovo": Stands for One-vs-one. Computes the average AUC of all possible + pairwise combinations of classes. Insensitive to class imbalance when + `average == "macro"`. + + Returns + ------- + pd.DataFrame + The ROC AUC score. + """ + return self._compute_metric_scores( + metrics.roc_auc_score, + X=X, + y_true=y, + data_source=data_source, + response_method=["predict_proba", "decision_function"], + metric_name=f"ROC AUC {self._SCORE_OR_LOSS_ICONS['roc_auc']}", + average=average, + multi_class=multi_class, + ) + + @available_if( + _check_supported_ml_task( + supported_ml_tasks=["binary-classification", "multiclass-classification"] + ) + ) + def log_loss(self, *, data_source="test", X=None, y=None): + """Compute the log loss. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + Returns + ------- + pd.DataFrame + The log-loss. + """ + return self._compute_metric_scores( + metrics.log_loss, + X=X, + y_true=y, + data_source=data_source, + response_method="predict_proba", + metric_name=f"Log loss {self._SCORE_OR_LOSS_ICONS['log_loss']}", + ) + + @available_if(_check_supported_ml_task(supported_ml_tasks=["regression"])) + def r2(self, *, data_source="test", X=None, y=None, multioutput="raw_values"): + """Compute the R² score. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + multioutput : {"raw_values", "uniform_average"} or array-like of shape \ + (n_outputs,), default="raw_values" + Defines aggregating of multiple output values. Array-like value defines + weights used to average errors. The other possible values are: + + - "raw_values": Returns a full set of errors in case of multioutput input. + - "uniform_average": Errors of all outputs are averaged with uniform weight. + + By default, no averaging is done. + + Returns + ------- + pd.DataFrame + The R² score. + """ + return self._compute_metric_scores( + metrics.r2_score, + X=X, + y_true=y, + data_source=data_source, + response_method="predict", + metric_name=f"R² {self._SCORE_OR_LOSS_ICONS['r2']}", + multioutput=multioutput, + ) + + @available_if(_check_supported_ml_task(supported_ml_tasks=["regression"])) + def rmse(self, *, data_source="test", X=None, y=None, multioutput="raw_values"): + """Compute the root mean squared error. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + multioutput : {"raw_values", "uniform_average"} or array-like of shape \ + (n_outputs,), default="raw_values" + Defines aggregating of multiple output values. Array-like value defines + weights used to average errors. The other possible values are: + + - "raw_values": Returns a full set of errors in case of multioutput input. + - "uniform_average": Errors of all outputs are averaged with uniform weight. + + By default, no averaging is done. + + Returns + ------- + pd.DataFrame + The root mean squared error. + """ + return self._compute_metric_scores( + metrics.root_mean_squared_error, + X=X, + y_true=y, + data_source=data_source, + response_method="predict", + metric_name=f"RMSE {self._SCORE_OR_LOSS_ICONS['rmse']}", + multioutput=multioutput, + ) + + def custom_metric( + self, + metric_function, + response_method, + *, + metric_name=None, + data_source="test", + X=None, + y=None, + **kwargs, + ): + """Compute a custom metric. + + It brings some flexibility to compute any desired metric. However, we need to + follow some rules: + + - `metric_function` should take `y_true` and `y_pred` as the first two + positional arguments. + - `response_method` corresponds to the estimator's method to be invoked to get + the predictions. It can be a string or a list of strings to defined in which + order the methods should be invoked. + + Parameters + ---------- + metric_function : callable + The metric function to be computed. The expected signature is + `metric_function(y_true, y_pred, **kwargs)`. + + response_method : str or list of str + The estimator's method to be invoked to get the predictions. The possible + values are: `predict`, `predict_proba`, `predict_log_proba`, and + `decision_function`. + + metric_name : str, default=None + The name of the metric. If not provided, it will be inferred from the + metric function. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + **kwargs : dict + Any additional keyword arguments to be passed to the metric function. + + Returns + ------- + pd.DataFrame + The custom metric. + """ + return self._compute_metric_scores( + metric_function, + X=X, + y_true=y, + data_source=data_source, + response_method=response_method, + metric_name=metric_name, + **kwargs, + ) + + #################################################################################### + # Methods related to the help tree + #################################################################################### + + def _sort_methods_for_help(self, methods): + """Override sort method for metrics-specific ordering. + + In short, we display the `report_metrics` first and then the `custom_metric`. + """ + + def _sort_key(method): + name = method[0] + if name == "custom_metric": + priority = 1 + elif name == "report_metrics": + priority = 2 + else: + priority = 0 + return priority, name + + return sorted(methods, key=_sort_key) + + def _format_method_name(self, name): + """Override format method for metrics-specific naming.""" + method_name = f"{name}(...)" + method_name = method_name.ljust(22) + if self._SCORE_OR_LOSS_ICONS[name] in ("(↗︎)", "(↘︎)"): + if self._SCORE_OR_LOSS_ICONS[name] == "(↗︎)": + method_name += f"[cyan]{self._SCORE_OR_LOSS_ICONS[name]}[/cyan]" + return method_name.ljust(43) + else: # (↘︎) + method_name += f"[orange1]{self._SCORE_OR_LOSS_ICONS[name]}[/orange1]" + return method_name.ljust(49) + else: + return method_name.ljust(29) + + def _get_methods_for_help(self): + """Override to exclude the plot accessor from methods list.""" + methods = super()._get_methods_for_help() + return [(name, method) for name, method in methods if name != "plot"] + + def _create_help_tree(self): + """Override to include plot methods in a separate branch.""" + tree = super()._create_help_tree() + + # Add plot methods in a separate branch + plot_branch = tree.add("[bold cyan].plot 🎨[/bold cyan]") + plot_methods = self.plot._get_methods_for_help() + plot_methods = self.plot._sort_methods_for_help(plot_methods) + + for name, method in plot_methods: + displayed_name = self.plot._format_method_name(name) + description = self.plot._get_method_description(method) + plot_branch.add(f".{displayed_name}".ljust(27) + f"- {description}") + + return tree + + def _get_help_panel_title(self): + return f"[bold cyan]{self._icon} Available metrics methods[/bold cyan]" + + def _get_help_legend(self): + return ( + "[cyan](↗︎)[/cyan] higher is better [orange1](↘︎)[/orange1] lower is better" + ) + + def _get_help_tree_title(self): + return "[bold cyan]reporter.metrics[/bold cyan]" + + +######################################################################################## +# Sub-accessors +# Plotting +######################################################################################## + + +class _PlotMetricsAccessor(_BaseAccessor): + """Plotting methods for the metrics accessor.""" + + def __init__(self, parent): + super().__init__(parent._parent, icon="🎨") + self._metrics_parent = parent + + def _get_display( + self, + *, + X, + y, + data_source, + response_method, + display_class, + display_kwargs, + display_plot_kwargs, + ): + """Get the display from the cache or compute it. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The data. + + y : array-like of shape (n_samples,) + The target. + + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + response_method : str + The response method. + + display_class : class + The display class. + + display_kwargs : dict + The display kwargs used by `display_class._from_predictions`. + + display_plot_kwargs : dict + The display kwargs used by `display.plot`. + + Returns + ------- + display : display_class + The display. + """ + X, y, data_source_hash = self._get_X_y_and_data_source_hash( + data_source=data_source, X=X, y=y + ) + + cache_key = (self._parent._hash, display_class.__name__) + cache_key += tuple(display_kwargs.values()) + cache_key += (data_source_hash,) if data_source_hash else (data_source,) + + if cache_key in self._parent._cache: + display = self._parent._cache[cache_key] + display.plot(**display_plot_kwargs) + else: + y_pred = self._parent._get_cached_response_values( + estimator_hash=self._parent._hash, + estimator=self._parent.estimator, + X=X, + response_method=response_method, + data_source=data_source, + data_source_hash=data_source_hash, + pos_label=display_kwargs.get("pos_label", None), + ) + + display = display_class._from_predictions( + y, + y_pred, + estimator=self._parent.estimator, + estimator_name=self._parent.estimator_name, + ml_task=self._parent._ml_task, + data_source=data_source, + **display_kwargs, + **display_plot_kwargs, + ) + self._parent._cache[cache_key] = display + + return display + + @available_if( + _check_supported_ml_task( + supported_ml_tasks=["binary-classification", "multiclass-classification"] + ) + ) + def roc(self, *, data_source="test", X=None, y=None, pos_label=None, ax=None): + """Plot the ROC curve. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + pos_label : str, default=None + The positive class. + + ax : matplotlib.axes.Axes, default=None + The axes to plot on. + + Returns + ------- + RocCurveDisplay + The ROC curve display. + """ + response_method = ("predict_proba", "decision_function") + display_kwargs = {"pos_label": pos_label} + display_plot_kwargs = {"ax": ax, "plot_chance_level": True, "despine": True} + return self._get_display( + X=X, + y=y, + data_source=data_source, + response_method=response_method, + display_class=RocCurveDisplay, + display_kwargs=display_kwargs, + display_plot_kwargs=display_plot_kwargs, + ) + + @available_if( + _check_supported_ml_task( + supported_ml_tasks=["binary-classification", "multiclass-classification"] + ) + ) + def precision_recall( + self, + *, + data_source="test", + X=None, + y=None, + pos_label=None, + ax=None, + ): + """Plot the precision-recall curve. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + pos_label : str, default=None + The positive class. + + ax : matplotlib.axes.Axes, default=None + The axes to plot on. + + Returns + ------- + PrecisionRecallCurveDisplay + The precision-recall curve display. + """ + response_method = ("predict_proba", "decision_function") + display_kwargs = {"pos_label": pos_label} + display_plot_kwargs = {"ax": ax, "plot_chance_level": False, "despine": True} + return self._get_display( + X=X, + y=y, + data_source=data_source, + response_method=response_method, + display_class=PrecisionRecallCurveDisplay, + display_kwargs=display_kwargs, + display_plot_kwargs=display_plot_kwargs, + ) + + @available_if(_check_supported_ml_task(supported_ml_tasks=["regression"])) + def prediction_error( + self, + *, + data_source="test", + X=None, + y=None, + ax=None, + kind="residual_vs_predicted", + subsample=1_000, + ): + """Plot the prediction error of a regression model. + + Extra keyword arguments will be passed to matplotlib's `plot`. + + Parameters + ---------- + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + X : array-like of shape (n_samples, n_features), default=None + New data on which to compute the metric. By default, we use the validation + set provided when creating the reporter. + + y : array-like of shape (n_samples,), default=None + New target on which to compute the metric. By default, we use the target + provided when creating the reporter. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + kind : {"actual_vs_predicted", "residual_vs_predicted"}, \ + default="residual_vs_predicted" + The type of plot to draw: + + - "actual_vs_predicted" draws the observed values (y-axis) vs. + the predicted values (x-axis). + - "residual_vs_predicted" draws the residuals, i.e. difference + between observed and predicted values, (y-axis) vs. the predicted + values (x-axis). + + subsample : float, int or None, default=1_000 + Sampling the samples to be shown on the scatter plot. If `float`, + it should be between 0 and 1 and represents the proportion of the + original dataset. If `int`, it represents the number of samples + display on the scatter plot. If `None`, no subsampling will be + applied. by default, 1,000 samples or less will be displayed. + + Returns + ------- + PredictionErrorDisplay + The prediction error display. + """ + display_kwargs = {"kind": kind, "subsample": subsample} + display_plot_kwargs = {"ax": ax} + return self._get_display( + X=X, + y=y, + data_source=data_source, + response_method="predict", + display_class=PredictionErrorDisplay, + display_kwargs=display_kwargs, + display_plot_kwargs=display_plot_kwargs, + ) + + def _get_help_panel_title(self): + return f"[bold cyan]{self._icon} Available plot methods[/bold cyan]" + + def _get_help_tree_title(self): + return "[bold cyan]reporter.metrics.plot[/bold cyan]" diff --git a/skore/src/skore/sklearn/_estimator/metrics_accessor.pyi b/skore/src/skore/sklearn/_estimator/metrics_accessor.pyi new file mode 100644 index 000000000..c73d136f1 --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/metrics_accessor.pyi @@ -0,0 +1,168 @@ +from typing import Any, Callable, Literal, Optional, Union + +import matplotlib.axes +import numpy as np +import pandas as pd +from sklearn.metrics import PrecisionRecallDisplay, RocCurveDisplay + +from skore.sklearn._estimator.base import _BaseAccessor +from skore.sklearn._plot import PredictionErrorDisplay + +class _PlotMetricsAccessor(_BaseAccessor): + _metrics_parent: _MetricsAccessor + + def __init__(self, parent: _MetricsAccessor) -> None: ... + def _get_display( + self, + *, + X: Optional[np.ndarray], + y: Optional[np.ndarray], + data_source: Literal["test", "train", "X_y"], + response_method: Union[str, list[str]], + display_class: Any, + display_kwargs: dict[str, Any], + display_plot_kwargs: dict[str, Any], + ) -> Union[RocCurveDisplay, PrecisionRecallDisplay, PredictionErrorDisplay]: ... + def roc( + self, + *, + data_source: Literal["test", "train", "X_y"] = "test", + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + pos_label: Optional[Union[str, int]] = None, + ax: Optional[matplotlib.axes.Axes] = None, + ) -> RocCurveDisplay: ... + def precision_recall( + self, + *, + data_source: Literal["test", "train", "X_y"] = "test", + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + pos_label: Optional[Union[str, int]] = None, + ax: Optional[matplotlib.axes.Axes] = None, + ) -> PrecisionRecallDisplay: ... + def prediction_error( + self, + *, + data_source: Literal["test", "train", "X_y"] = "test", + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + ax: Optional[matplotlib.axes.Axes] = None, + kind: Literal[ + "actual_vs_predicted", "residual_vs_predicted" + ] = "residual_vs_predicted", + subsample: Optional[Union[int, float]] = 1_000, + ) -> PredictionErrorDisplay: ... + +class _MetricsAccessor(_BaseAccessor): + _SCORE_OR_LOSS_ICONS: dict[str, str] + plot: _PlotMetricsAccessor + + def _compute_metric_scores( + self, + metric_fn: Callable, + X: Optional[np.ndarray], + y_true: Optional[np.ndarray], + *, + data_source: Literal["test", "train", "X_y"] = "test", + response_method: Union[str, list[str]], + pos_label: Optional[Union[str, int]] = None, + metric_name: Optional[str] = None, + **metric_kwargs: Any, + ) -> pd.DataFrame: ... + def report_metrics( + self, + *, + data_source: Literal["test", "train", "X_y"] = "test", + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + scoring: Optional[Union[list[str], Callable]] = None, + pos_label: Optional[Union[str, int]] = None, + scoring_kwargs: Optional[dict[str, Any]] = None, + ) -> pd.DataFrame: ... + def accuracy( + self, + *, + data_source: Literal["test", "train", "X_y"] = "test", + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + ) -> pd.DataFrame: ... + def precision( + self, + *, + data_source: Literal["test", "train", "X_y"] = "test", + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + average: Optional[ + Literal["binary", "micro", "macro", "weighted", "samples"] + ] = None, + pos_label: Optional[Union[str, int]] = None, + ) -> pd.DataFrame: ... + def recall( + self, + *, + data_source: Literal["test", "train", "X_y"] = "test", + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + average: Optional[ + Literal["binary", "micro", "macro", "weighted", "samples"] + ] = None, + pos_label: Optional[Union[str, int]] = None, + ) -> pd.DataFrame: ... + def brier_score( + self, + *, + data_source: Literal["test", "train", "X_y"] = "test", + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + pos_label: Optional[Union[str, int]] = None, + ) -> pd.DataFrame: ... + def roc_auc( + self, + *, + data_source: Literal["test", "train", "X_y"] = "test", + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + average: Optional[ + Literal["auto", "micro", "macro", "weighted", "samples"] + ] = None, + multi_class: Literal["raise", "ovr", "ovo"] = "ovr", + ) -> pd.DataFrame: ... + def log_loss( + self, + *, + data_source: Literal["test", "train", "X_y"] = "test", + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + ) -> pd.DataFrame: ... + def r2( + self, + *, + data_source: Literal["test", "train", "X_y"] = "test", + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + multioutput: Union[ + Literal["raw_values", "uniform_average"], np.ndarray + ] = "raw_values", + ) -> pd.DataFrame: ... + def rmse( + self, + *, + data_source: Literal["test", "train", "X_y"] = "test", + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + multioutput: Union[ + Literal["raw_values", "uniform_average"], np.ndarray + ] = "raw_values", + ) -> pd.DataFrame: ... + def custom_metric( + self, + metric_function: Callable, + response_method: Union[str, list[str]], + *, + metric_name: Optional[str] = None, + data_source: Literal["test", "train", "X_y"] = "test", + X: Optional[np.ndarray] = None, + y: Optional[np.ndarray] = None, + **kwargs: Any, + ) -> pd.DataFrame: ... diff --git a/skore/src/skore/sklearn/_estimator/report.py b/skore/src/skore/sklearn/_estimator/report.py new file mode 100644 index 000000000..35b4e6425 --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/report.py @@ -0,0 +1,413 @@ +import inspect +import time +from itertools import product + +import joblib +import numpy as np +from rich.progress import track +from rich.tree import Tree +from sklearn.base import clone +from sklearn.exceptions import NotFittedError +from sklearn.pipeline import Pipeline +from sklearn.utils._response import _check_response_method, _get_response_values +from sklearn.utils.validation import check_is_fitted + +from skore.externals._pandas_accessors import DirNamesMixin +from skore.externals._sklearn_compat import is_clusterer +from skore.sklearn._estimator.base import _BaseAccessor, _HelpMixin +from skore.sklearn.find_ml_task import _find_ml_task + + +class EstimatorReport(_HelpMixin, DirNamesMixin): + """Report for a fitted estimator. + + This class provides a set of tools to quickly validate and inspect a scikit-learn + compatible estimator. + + Parameters + ---------- + estimator : estimator object + Estimator to make report from. + + fit : {"auto", True, False}, default="auto" + Whether to fit the estimator on the training data. If "auto", the estimator + is fitted only if the training data is provided. + + X_train : {array-like, sparse matrix} of shape (n_samples, n_features) or \ + None + Training data. + + y_train : array-like of shape (n_samples,) or (n_samples, n_outputs) or None + Training target. + + X_test : {array-like, sparse matrix} of shape (n_samples, n_features) or None + Testing data. It should have the same structure as the training data. + + y_test : array-like of shape (n_samples,) or (n_samples, n_outputs) or None + Testing target. + + Attributes + ---------- + metrics : _MetricsAccessor + Accessor for metrics-related operations. + + Examples + -------- + >>> from sklearn.datasets import make_classification + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.linear_model import LogisticRegression + >>> X, y = make_classification(random_state=42) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + >>> estimator = LogisticRegression().fit(X_train, y_train) + >>> from skore import EstimatorReport + >>> report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + """ + + _ACCESSOR_CONFIG = { + "metrics": {"icon": "📏", "name": "metrics"}, + # Add other accessors as they're implemented + # "inspection": {"icon": "🔍", "name": "inspection"}, + # "linting": {"icon": "✔️", "name": "linting"}, + } + + @staticmethod + def _fit_estimator(estimator, X_train, y_train): + if X_train is None or (y_train is None and not is_clusterer(estimator)): + raise ValueError( + "The training data is required to fit the estimator. " + "Please provide both X_train and y_train." + ) + return clone(estimator).fit(X_train, y_train) + + def __init__( + self, + estimator, + *, + fit="auto", + X_train=None, + y_train=None, + X_test=None, + y_test=None, + ): + if fit == "auto": + try: + check_is_fitted(estimator) + self._estimator = estimator + except NotFittedError: + self._estimator = self._fit_estimator(estimator, X_train, y_train) + elif fit is True: + self._estimator = self._fit_estimator(estimator, X_train, y_train) + else: # fit is False + self._estimator = estimator + + # private storage to be able to invalidate the cache when the user alters + # those attributes + self._X_train = X_train + self._y_train = y_train + self._X_test = X_test + self._y_test = y_test + + self._initialize_state() + + def _initialize_state(self): + """Initialize/reset the random number generator, hash, and cache.""" + self._rng = np.random.default_rng(time.time_ns()) + self._hash = self._rng.integers( + low=np.iinfo(np.int64).min, high=np.iinfo(np.int64).max + ) + self._cache = {} + self._ml_task = _find_ml_task(self._y_test, estimator=self._estimator) + + # NOTE: + # For the moment, we do not allow to alter the estimator and the training data. + # For the validation set, we allow it and we invalidate the cache. + + def clean_cache(self): + """Clean the cache.""" + self._cache = {} + + def cache_predictions(self, response_methods="auto", n_jobs=None): + """Force caching of estimator's predictions. + + Parameters + ---------- + response_methods : "auto" or list of str, default="auto" + The response methods to precompute. If "auto", the response methods are + inferred from the ml task: for classification we compute the response of + the `predict_proba`, `decision_function` and `predict` methods; for + regression we compute the response of the `predict` method. + + n_jobs : int or None, default=None + The number of jobs to run in parallel. None means 1 unless in a + joblib.parallel_backend context. -1 means using all processors. + """ + if self._ml_task in ("binary-classification", "multiclass-classification"): + if response_methods == "auto": + response_methods = ("predict",) + if hasattr(self._estimator, "predict_proba"): + response_methods = ("predict_proba",) + if hasattr(self._estimator, "decision_function"): + response_methods = ("decision_function",) + pos_labels = self._estimator.classes_ + else: + if response_methods == "auto": + response_methods = ("predict",) + pos_labels = [None] + + data_sources = ("test",) + Xs = (self._X_test,) + if self._X_train is not None: + data_sources = ("train",) + Xs = (self._X_train,) + + parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator_unordered") + generator = parallel( + joblib.delayed(self._get_cached_response_values)( + estimator_hash=self._hash, + estimator=self._estimator, + X=X, + response_method=response_method, + pos_label=pos_label, + data_source=data_source, + ) + for response_method, pos_label, data_source, X in product( + response_methods, pos_labels, data_sources, Xs + ) + ) + # trigger the computation + list( + track( + generator, + total=len(response_methods) * len(pos_labels) * len(data_sources), + description="Caching predictions", + ) + ) + + @property + def estimator(self): + return self._estimator + + @estimator.setter + def estimator(self, value): + raise AttributeError( + "The estimator attribute is immutable. " + "Call the constructor of {self.__class__.__name__} to create a new report." + ) + + @property + def X_train(self): + return self._X_train + + @X_train.setter + def X_train(self, value): + raise AttributeError( + "The X_train attribute is immutable. " + "Please use the `from_unfitted_estimator` method to create a new report." + ) + + @property + def y_train(self): + return self._y_train + + @y_train.setter + def y_train(self, value): + raise AttributeError( + "The y_train attribute is immutable. " + "Please use the `from_unfitted_estimator` method to create a new report." + ) + + @property + def X_test(self): + return self._X_test + + @X_test.setter + def X_test(self, value): + self._X_test = value + self._initialize_state() + + @property + def y_test(self): + return self._y_test + + @y_test.setter + def y_test(self, value): + self._y_test = value + self._initialize_state() + + @property + def estimator_name(self): + if isinstance(self._estimator, Pipeline): + name = self._estimator[-1].__class__.__name__ + else: + name = self._estimator.__class__.__name__ + return name + + def _get_cached_response_values( + self, + *, + estimator_hash, + estimator, + X, + response_method, + pos_label=None, + data_source="test", + data_source_hash=None, + ): + """Compute or load from local cache the response values. + + Parameters + ---------- + estimator_hash : int + A hash associated with the estimator such that we can retrieve the data from + the cache. + + estimator : estimator object + The estimator. + + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The data. + + response_method : str + The response method. + + pos_label : str, default=None + The positive label. + + data_source : {"test", "train", "X_y"}, default="test" + The data source to use. + + - "test" : use the test set provided when creating the reporter. + - "train" : use the train set provided when creating the reporter. + - "X_y" : use the provided `X` and `y` to compute the metric. + + data_source_hash : int or None + The hash of the data source when `data_source` is "X_y". + + Returns + ------- + array-like of shape (n_samples,) or (n_samples, n_outputs) + The response values. + """ + prediction_method = _check_response_method(estimator, response_method).__name__ + if prediction_method in ("predict_proba", "decision_function"): + # pos_label is only important in classification and with probabilities + # and decision functions + cache_key = (estimator_hash, pos_label, prediction_method, data_source) + else: + cache_key = (estimator_hash, prediction_method, data_source) + + if data_source == "X_y": + data_source_hash = joblib.hash(X) + cache_key += (data_source_hash,) + + if cache_key in self._cache: + return self._cache[cache_key] + + predictions, _ = _get_response_values( + estimator, + X=X, + response_method=prediction_method, + pos_label=pos_label, + return_response_method_used=False, + ) + self._cache[cache_key] = predictions + + return predictions + + #################################################################################### + # Methods related to the help tree + #################################################################################### + + def _get_help_panel_title(self): + return ( + f"[bold cyan]📓 Tools to diagnose estimator " + f"{self.estimator_name}[/bold cyan]" + ) + + def _get_help_legend(self): + return ( + "[cyan](↗︎)[/cyan] higher is better [orange1](↘︎)[/orange1] lower is better" + ) + + def _get_attributes_for_help(self): + """Get the public attributes to display in help.""" + attributes = [] + xy_attributes = [] + + for name in dir(self): + # Skip private attributes, callables, and accessors + if ( + name.startswith("_") + or callable(getattr(self, name)) + or isinstance(getattr(self, name), _BaseAccessor) + ): + continue + + # Group X and y attributes separately + value = getattr(self, name) + if name.startswith(("X_", "y_")): + if value is not None: # Only include non-None X/y attributes + xy_attributes.append(name) + else: + attributes.append(name) + + # Sort X/y attributes to keep them grouped + xy_attributes.sort() + attributes.sort() + + # Return X/y attributes first, followed by other attributes + return xy_attributes + attributes + + def _create_help_tree(self): + """Create a rich Tree with the available tools and accessor methods.""" + tree = Tree("reporter") + + # Add accessor methods first + for accessor_attr, config in self._ACCESSOR_CONFIG.items(): + accessor = getattr(self, accessor_attr) + branch = tree.add( + f"[bold cyan].{config['name']} {config['icon']}[/bold cyan]" + ) + + # Add main accessor methods first + methods = accessor._get_methods_for_help() + methods = accessor._sort_methods_for_help(methods) + + # Add methods + for name, method in methods: + displayed_name = accessor._format_method_name(name) + description = accessor._get_method_description(method) + branch.add(f".{displayed_name} - {description}") + + # Add sub-accessors after main methods + for sub_attr, sub_obj in inspect.getmembers(accessor): + if isinstance(sub_obj, _BaseAccessor) and not sub_attr.startswith("_"): + sub_branch = branch.add( + f"[bold cyan].{sub_attr} {sub_obj._icon}[/bold cyan]" + ) + + # Add sub-accessor methods + sub_methods = sub_obj._get_methods_for_help() + sub_methods = sub_obj._sort_methods_for_help(sub_methods) + + for name, method in sub_methods: + displayed_name = sub_obj._format_method_name(name) + description = sub_obj._get_method_description(method) + sub_branch.add(f".{displayed_name.ljust(25)} - {description}") + + # Add base methods + base_methods = self._get_methods_for_help() + base_methods = self._sort_methods_for_help(base_methods) + + for name, method in base_methods: + description = self._get_method_description(method) + tree.add(f".{name}(...)".ljust(34) + f" - {description}") + + # Add attributes section + attributes = self._get_attributes_for_help() + if attributes: + attr_branch = tree.add("[bold cyan]Attributes[/bold cyan]") + for attr in attributes: + attr_branch.add(f".{attr}") + + return tree diff --git a/skore/src/skore/sklearn/_estimator/report.pyi b/skore/src/skore/sklearn/_estimator/report.pyi new file mode 100644 index 000000000..74c4c215d --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/report.pyi @@ -0,0 +1,71 @@ +from typing import Any, Literal, Optional, Union + +import numpy as np +from sklearn.base import BaseEstimator + +from skore.sklearn._estimator.base import _HelpMixin +from skore.sklearn._estimator.metrics_accessor import _MetricsAccessor + +class EstimatorReport(_HelpMixin): + _ACCESSOR_CONFIG: dict[str, dict[str, str]] + _estimator: BaseEstimator + _X_train: Optional[np.ndarray] + _y_train: Optional[np.ndarray] + _X_test: Optional[np.ndarray] + _y_test: Optional[np.ndarray] + _rng: np.random.Generator + _hash: int + _cache: dict[Any, Any] + _ml_task: str + metrics: _MetricsAccessor + + @staticmethod + def _fit_estimator( + estimator: BaseEstimator, X_train: np.ndarray, y_train: Optional[np.ndarray] + ) -> BaseEstimator: ... + def __init__( + self, + estimator: BaseEstimator, + *, + fit: Literal["auto", True, False] = "auto", + X_train: Optional[np.ndarray] = None, + y_train: Optional[np.ndarray] = None, + X_test: Optional[np.ndarray] = None, + y_test: Optional[np.ndarray] = None, + ) -> None: ... + def _initialize_state(self) -> None: ... + def clean_cache(self) -> None: ... + def cache_predictions( + self, + response_methods: Union[Literal["auto"], list[str]] = "auto", + n_jobs: Optional[int] = None, + ) -> None: ... + @property + def estimator(self) -> BaseEstimator: ... + @property + def X_train(self) -> Optional[np.ndarray]: ... + @property + def y_train(self) -> Optional[np.ndarray]: ... + @property + def X_test(self) -> Optional[np.ndarray]: ... + @X_test.setter + def X_test(self, value: Optional[np.ndarray]) -> None: ... + @property + def y_test(self) -> Optional[np.ndarray]: ... + @y_test.setter + def y_test(self, value: Optional[np.ndarray]) -> None: ... + @property + def estimator_name(self) -> str: ... + def _get_cached_response_values( + self, + *, + estimator_hash: int, + estimator: BaseEstimator, + X: np.ndarray, + response_method: Union[str, list[str]], + pos_label: Optional[Union[str, int]] = None, + data_source: Literal["test", "train", "X_y"] = "test", + data_source_hash: Optional[str] = None, + ) -> np.ndarray: ... + def _get_help_panel_title(self) -> str: ... + def _create_help_tree(self) -> Any: ... # Returns rich.tree.Tree diff --git a/skore/src/skore/sklearn/_estimator/utils.py b/skore/src/skore/sklearn/_estimator/utils.py new file mode 100644 index 000000000..578fab610 --- /dev/null +++ b/skore/src/skore/sklearn/_estimator/utils.py @@ -0,0 +1,19 @@ +from sklearn.pipeline import Pipeline + + +def _check_supported_estimator(supported_estimators): + def check(accessor): + estimator = accessor._parent.estimator + if isinstance(estimator, Pipeline): + estimator = estimator.steps[-1][1] + supported_estimator = isinstance(estimator, supported_estimators) + + if not supported_estimator: + raise AttributeError( + f"The {estimator.__class__.__name__} estimator is not supported " + "by the function called." + ) + + return True + + return check diff --git a/skore/src/skore/sklearn/_plot/__init__.py b/skore/src/skore/sklearn/_plot/__init__.py new file mode 100644 index 000000000..7f39733e4 --- /dev/null +++ b/skore/src/skore/sklearn/_plot/__init__.py @@ -0,0 +1,9 @@ +from skore.sklearn._plot.precision_recall_curve import PrecisionRecallCurveDisplay +from skore.sklearn._plot.prediction_error import PredictionErrorDisplay +from skore.sklearn._plot.roc_curve import RocCurveDisplay + +__all__ = [ + "RocCurveDisplay", + "PrecisionRecallCurveDisplay", + "PredictionErrorDisplay", +] diff --git a/skore/src/skore/sklearn/_plot/precision_recall_curve.py b/skore/src/skore/sklearn/_plot/precision_recall_curve.py new file mode 100644 index 000000000..96881362b --- /dev/null +++ b/skore/src/skore/sklearn/_plot/precision_recall_curve.py @@ -0,0 +1,511 @@ +from collections import Counter + +from sklearn.metrics import average_precision_score, precision_recall_curve +from sklearn.preprocessing import LabelBinarizer + +from skore.sklearn._plot.utils import ( + HelpDisplayMixin, + _ClassifierCurveDisplayMixin, + _despine_matplotlib_axis, + _validate_style_kwargs, +) + + +class PrecisionRecallCurveDisplay(HelpDisplayMixin, _ClassifierCurveDisplayMixin): + """Precision Recall visualization. + + An instance of this class is should created by + `EstimatorReport.metrics.plot.precision_recall()`. You should not create an + instance of this class directly. + + + Parameters + ---------- + precision : dict of list of ndarray + Precision values. The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `ndarray`, each `ndarray` being the precision. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `ndarray`, each `ndarray` being the precision. + + recall : dict of list of ndarray + Recall values. The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `ndarray`, each `ndarray` being the recall. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `ndarray`, each `ndarray` being the recall. + + average_precision : dict of list of float + Average precision. The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `float`, each `float` being the average + precision. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `float`, each `float` being the average + precision. + + prevalence : dict of list of float + The prevalence of the positive label. The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `float`, each `float` being the prevalence. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `float`, each `float` being the prevalence. + + estimator_name : str + Name of the estimator. + + pos_label : int, float, bool or str, default=None + The class considered as the positive class. If None, the class will not + be shown in the legend. + + data_source : {"train", "test", "X_y"}, default=None + The data source used to compute the precision recall curve. + + Attributes + ---------- + ax_ : matplotlib Axes + Axes with precision recall curve. + + figure_ : matplotlib Figure + Figure containing the curve. + + lines_ : list of matplotlib Artist + Precision recall curve. + + chance_levels_ : matplotlib Artist or None + The chance level line. It is `None` if the chance level is not plotted. + """ + + def __init__( + self, + precision, + recall, + *, + average_precision, + prevalence, + estimator_name, + pos_label=None, + data_source=None, + ): + self.precision = precision + self.recall = recall + self.average_precision = average_precision + self.prevalence = prevalence + self.estimator_name = estimator_name + self.pos_label = pos_label + self.data_source = data_source + + def plot( + self, + ax=None, + *, + estimator_name=None, + pr_curve_kwargs=None, + plot_chance_level=False, + chance_level_kwargs=None, + despine=True, + ): + """Plot visualization. + + Extra keyword arguments will be passed to matplotlib's `plot`. + + Parameters + ---------- + ax : Matplotlib Axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + estimator_name : str, default=None + Name of the estimator used to plot the precision-recall curve. If + `None`, we use the inferred name from the estimator. + + plot_chance_level : bool, default=True + Whether to plot the chance level. The chance level is the prevalence + of the positive label computed from the data passed during + :meth:`from_estimator` or :meth:`from_predictions` call. + + pr_curve_kwargs : dict or list of dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the precision-recall curve(s). + + chance_level_kwargs : dict or list of dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + + Returns + ------- + display : PrecisionRecallCurveDisplay + Object that stores computed values. + + Notes + ----- + The average precision (cf. :func:`~sklearn.metrics.average_precision_score`) + in scikit-learn is computed without any interpolation. To be consistent + with this metric, the precision-recall curve is plotted without any + interpolation as well (step-wise style). + + You can change this style by passing the keyword argument + `drawstyle="default"`. However, the curve will not be strictly + consistent with the reported average precision. + """ + self.ax_, self.figure_, estimator_name = self._validate_plot_params( + ax=ax, estimator_name=estimator_name + ) + + self.lines_ = [] + self.chance_levels_ = [] + if len(self.precision) == 1: # binary-classification + if len(self.precision[self.pos_label]) == 1: # single-split + if pr_curve_kwargs is None: + pr_curve_kwargs = {} + elif isinstance(pr_curve_kwargs, list): + if len(pr_curve_kwargs) > 1: + raise ValueError( + "You intend to plot a single precision-recall curve and " + "provide multiple precision-recall curve keyword " + "arguments. Provide a single dictionary or a list with " + "a single dictionary." + ) + pr_curve_kwargs = pr_curve_kwargs[0] + + precision = self.precision[self.pos_label][0] + recall = self.recall[self.pos_label][0] + average_precision = self.average_precision[self.pos_label][0] + prevalence = self.prevalence[self.pos_label][0] + + default_line_kwargs = {"drawstyle": "steps-post"} + if average_precision is not None and self.data_source in ( + "train", + "test", + ): + default_line_kwargs["label"] = ( + f"{self.data_source.title()} set " + f"(AP = {average_precision:0.2f})" + ) + elif average_precision is not None: # data_source in (None, "X_y") + default_line_kwargs["label"] = f"AP = {average_precision:0.2f}" + + line_kwargs = _validate_style_kwargs( + default_line_kwargs, pr_curve_kwargs + ) + + (line_,) = self.ax_.plot(recall, precision, **line_kwargs) + self.lines_.append(line_) + + if plot_chance_level: + default_chance_level_line_kwargs = { + "label": f"Chance level (AP = {prevalence:0.2f})", + "color": "k", + "linestyle": "--", + } + + if chance_level_kwargs is None: + chance_level_kwargs = {} + elif isinstance(chance_level_kwargs, list): + if len(chance_level_kwargs) > 1: + raise ValueError( + "You intend to plot a single chance level line and " + "provide multiple chance level line keyword " + "arguments. Provide a single dictionary or a list " + "with a single dictionary." + ) + chance_level_kwargs = chance_level_kwargs[0] + + chance_level_line_kwargs = _validate_style_kwargs( + default_chance_level_line_kwargs, chance_level_kwargs + ) + + (chance_level_,) = self.ax_.plot( + (0, 1), (prevalence, prevalence), **chance_level_line_kwargs + ) + self.chance_levels_.append(chance_level_) + else: + self.chance_levels_ = None + else: # cross-validation + raise NotImplementedError( + "We don't support yet cross-validation" + ) # pragma: no cover + + info_pos_label = ( + f"\n(Positive label: {self.pos_label})" + if self.pos_label is not None + else "" + ) + else: # multiclass-classification + info_pos_label = None # irrelevant for multiclass + if pr_curve_kwargs is None: + pr_curve_kwargs = [{}] * len(self.precision) + elif isinstance(pr_curve_kwargs, list): + if len(pr_curve_kwargs) != len(self.precision): + raise ValueError( + "You intend to plot multiple precision-recall curves. We " + "expect `pr_curve_kwargs` to be a list of dictionaries with " + "the same length as the number of precision-recall curves. " + "Got " + f"{len(pr_curve_kwargs)} instead of " + f"{len(self.precision)}." + ) + else: + raise ValueError( + "You intend to plot multiple precision-recall curves. We expect " + "`pr_curve_kwargs` to be a list of dictionaries of " + f"{len(self.precision)} elements. Got {pr_curve_kwargs!r} instead." + ) + + if plot_chance_level: + if chance_level_kwargs is None: + chance_level_kwargs = [{}] * len(self.precision) + elif isinstance(chance_level_kwargs, list): + if len(chance_level_kwargs) != len(self.precision): + raise ValueError( + "You intend to plot multiple precision-recall curves. We " + "expect `chance_level_kwargs` to be a list of dictionaries " + "with the same length as the number of precision-recall " + "curves. Got " + f"{len(chance_level_kwargs)} instead of " + f"{len(self.precision)}." + ) + else: + raise ValueError( + "You intend to plot multiple precision-recall curves. We " + "expect `chance_level_kwargs` to be a list of dictionaries of " + f"{len(self.precision)} elements. Got {chance_level_kwargs!r} " + "instead." + ) + + for class_idx, class_ in enumerate(self.precision): + precision_class = self.precision[class_] + recall_class = self.recall[class_] + average_precision_class = self.average_precision[class_] + prevalence_class = self.prevalence[class_] + pr_curve_kwargs_class = pr_curve_kwargs[class_idx] + + if len(precision_class) == 1: # single-split + precision = precision_class[0] + recall = recall_class[0] + average_precision = average_precision_class[0] + prevalence = prevalence_class[0] + + default_line_kwargs = {"drawstyle": "steps-post"} + if average_precision is not None and self.data_source in ( + "train", + "test", + ): + default_line_kwargs["label"] = ( + f"{str(class_).title()} - {self.data_source} set " + f"(AP = {average_precision:0.2f})" + ) + elif average_precision is not None: # data_source in (None, "X_y") + default_line_kwargs["label"] = ( + f"{str(class_).title()} AP = {average_precision:0.2f}" + ) + + line_kwargs = _validate_style_kwargs( + default_line_kwargs, pr_curve_kwargs_class + ) + + (line_,) = self.ax_.plot(recall, precision, **line_kwargs) + self.lines_.append(line_) + + if plot_chance_level: + chance_level_kwargs_class = chance_level_kwargs[class_idx] + + default_chance_level_line_kwargs = { + "label": ( + f"Chance level - {str(class_).title()} " + f"(AP = {prevalence:0.2f})" + ), + "color": "k", + "linestyle": "--", + } + + chance_level_line_kwargs = _validate_style_kwargs( + default_chance_level_line_kwargs, chance_level_kwargs_class + ) + + (chance_level_,) = self.ax_.plot( + (0, 1), (prevalence, prevalence), **chance_level_line_kwargs + ) + self.chance_levels_.append(chance_level_) + else: + self.chance_levels_ = None + else: # cross-validation + raise NotImplementedError( + "We don't support yet cross-validation" + ) # pragma: no cover + + xlabel = "Recall" + ylabel = "Precision" + if info_pos_label: + xlabel += info_pos_label + ylabel += info_pos_label + + self.ax_.set( + xlabel=xlabel, + xlim=(-0.01, 1.01), + ylabel=ylabel, + ylim=(-0.01, 1.01), + aspect="equal", + ) + + if despine: + _despine_matplotlib_axis(self.ax_) + + self.ax_.legend(loc="lower left", title=estimator_name) + + @classmethod + def _from_predictions( + cls, + y_true, + y_pred, + *, + estimator, + estimator_name, + ml_task, + data_source=None, + pos_label=None, + drop_intermediate=False, + ax=None, + pr_curve_kwargs=None, + plot_chance_level=False, + chance_level_kwargs=None, + despine=True, + ): + """Plot precision-recall curve given binary class predictions. + + Parameters + ---------- + y_true : array-like of shape (n_samples,) + True binary labels. + + y_pred : array-like of shape (n_samples,) + Target scores, can either be probability estimates of the positive class, + confidence values, or non-thresholded measure of decisions (as returned by + “decision_function” on some classifiers). + + estimator : estimator instance + The estimator from which `y_pred` is obtained. + + estimator_name : str + Name of the estimator used to plot the precision-recall curve. + + ml_task : {"binary-classification", "multiclass-classification"} + The machine learning task. + + data_source : {"train", "test", "X_y"}, default=None + The data source used to compute the ROC curve. + + pos_label : int, float, bool or str, default=None + The class considered as the positive class when computing the + precision and recall metrics. + + drop_intermediate : bool, default=False + Whether to drop some suboptimal thresholds which would not appear + on a plotted precision-recall curve. This is useful in order to + create lighter precision-recall curves. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is created. + + pr_curve_kwargs : dict or list of dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the precision-recall curve(s). + + plot_chance_level : bool, default=False + Whether to plot the chance level. The chance level is the prevalence + of the positive label computed from the data passed during + :meth:`from_estimator` or :meth:`from_predictions` call. + + chance_level_kwargs : dict or list of dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + + **kwargs : dict + Keyword arguments to be passed to matplotlib's `plot`. + + Returns + ------- + display : :class:`~sklearn.metrics.PrecisionRecallDisplay` + """ + pos_label_validated = cls._validate_from_predictions_params( + y_true, y_pred, ml_task=ml_task, pos_label=pos_label + ) + + if ml_task == "binary-classification": + precision, recall, _ = precision_recall_curve( + y_true, + y_pred, + pos_label=pos_label_validated, + drop_intermediate=drop_intermediate, + ) + average_precision = average_precision_score( + y_true, y_pred, pos_label=pos_label_validated + ) + + class_count = Counter(y_true) + prevalence = class_count[pos_label_validated] / sum(class_count.values()) + + precision = {pos_label_validated: [precision]} + recall = {pos_label_validated: [recall]} + average_precision = {pos_label_validated: [average_precision]} + prevalence = {pos_label_validated: [prevalence]} + else: # multiclass-classification + precision, recall, average_precision, prevalence = {}, {}, {}, {} + label_binarizer = LabelBinarizer().fit(estimator.classes_) + y_true_onehot = label_binarizer.transform(y_true) + for class_idx, class_ in enumerate(estimator.classes_): + precision_class, recall_class, _ = precision_recall_curve( + y_true_onehot[:, class_idx], + y_pred[:, class_idx], + pos_label=None, + drop_intermediate=drop_intermediate, + ) + average_precision_class = average_precision_score( + y_true_onehot[:, class_idx], y_pred[:, class_idx] + ) + class_count = Counter(y_true) + prevalence_class = class_count[class_] / sum(class_count.values()) + + precision[class_] = [precision_class] + recall[class_] = [recall_class] + average_precision[class_] = [average_precision_class] + prevalence[class_] = [prevalence_class] + + viz = cls( + precision=precision, + recall=recall, + average_precision=average_precision, + prevalence=prevalence, + estimator_name=estimator_name, + pos_label=pos_label_validated, + data_source=data_source, + ) + + viz.plot( + ax=ax, + estimator_name=estimator_name, + pr_curve_kwargs=pr_curve_kwargs, + plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kwargs, + despine=despine, + ) + + return viz diff --git a/skore/src/skore/sklearn/_plot/prediction_error.py b/skore/src/skore/sklearn/_plot/prediction_error.py new file mode 100644 index 000000000..392f6058b --- /dev/null +++ b/skore/src/skore/sklearn/_plot/prediction_error.py @@ -0,0 +1,318 @@ +import numbers + +import matplotlib.pyplot as plt +import numpy as np +from sklearn.utils.validation import check_random_state + +from skore.externals._sklearn_compat import _safe_indexing +from skore.sklearn._plot.utils import ( + HelpDisplayMixin, + _despine_matplotlib_axis, + _validate_style_kwargs, +) + + +class PredictionErrorDisplay(HelpDisplayMixin): + """Visualization of the prediction error of a regression model. + + This tool can display "residuals vs predicted" or "actual vs predicted" + using scatter plots to qualitatively assess the behavior of a regressor, + preferably on held-out data points. + + An instance of this class is should created by + `EstimatorReport.metrics.plot.prediction_error()`. + You should not create an instance of this class directly. + + Parameters + ---------- + ----------z + y_true : ndarray of shape (n_samples,) + True values. + + y_pred : ndarray of shape (n_samples,) + Prediction values. + + estimator_name : str + Name of the estimator. + + data_source : {"train", "test", "X_y"}, default=None + The data source used to compute the ROC curve. + + Attributes + ---------- + line_ : matplotlib Artist + Optimal line representing `y_true == y_pred`. Therefore, it is a + diagonal line for `kind="predictions"` and a horizontal line for + `kind="residuals"`. + + errors_lines_ : matplotlib Artist or None + Residual lines. If `with_errors=False`, then it is set to `None`. + + scatter_ : matplotlib Artist + Scatter data points. + + ax_ : matplotlib Axes + Axes with the different matplotlib axis. + + figure_ : matplotlib Figure + Figure containing the scatter and lines. + """ + + def __init__(self, *, y_true, y_pred, estimator_name, data_source=None): + self.y_true = y_true + self.y_pred = y_pred + self.estimator_name = estimator_name + self.data_source = data_source + + def plot( + self, + ax=None, + *, + estimator_name=None, + kind="residual_vs_predicted", + scatter_kwargs=None, + line_kwargs=None, + despine=True, + ): + """Plot visualization. + + Extra keyword arguments will be passed to matplotlib's ``plot``. + + Parameters + ---------- + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + estimator_name : str, default=None + Name of the estimator used to plot the prediction error. If `None`, + we used the inferred name from the estimator. + + kind : {"actual_vs_predicted", "residual_vs_predicted"}, \ + default="residual_vs_predicted" + The type of plot to draw: + + - "actual_vs_predicted" draws the observed values (y-axis) vs. + the predicted values (x-axis). + - "residual_vs_predicted" draws the residuals, i.e. difference + between observed and predicted values, (y-axis) vs. the predicted + values (x-axis). + + scatter_kwargs : dict, default=None + Dictionary with keywords passed to the `matplotlib.pyplot.scatter` + call. + + line_kwargs : dict, default=None + Dictionary with keyword passed to the `matplotlib.pyplot.plot` + call to draw the optimal line. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + + Returns + ------- + display : PredictionErrorDisplay + Object that stores computed values. + """ + expected_kind = ("actual_vs_predicted", "residual_vs_predicted") + if kind not in expected_kind: + raise ValueError( + f"`kind` must be one of {', '.join(expected_kind)}. " + f"Got {kind!r} instead." + ) + + if scatter_kwargs is None: + scatter_kwargs = {} + if line_kwargs is None: + line_kwargs = {} + + default_scatter_kwargs = {"color": "tab:blue", "alpha": 0.8} + default_line_kwargs = {"color": "black", "alpha": 0.7, "linestyle": "--"} + + scatter_kwargs = _validate_style_kwargs(default_scatter_kwargs, scatter_kwargs) + line_kwargs = _validate_style_kwargs(default_line_kwargs, line_kwargs) + + scatter_kwargs = {**default_scatter_kwargs, **scatter_kwargs} + line_kwargs = {**default_line_kwargs, **line_kwargs} + + if self.data_source in ("train", "test"): + scatter_label = f"{self.data_source.title()} set" + else: + scatter_label = "Data set" + + if estimator_name is None: + estimator_name = self.estimator_name + + if ax is None: + _, ax = plt.subplots() + + if kind == "actual_vs_predicted": + max_value = max(np.max(self.y_true), np.max(self.y_pred)) + min_value = min(np.min(self.y_true), np.min(self.y_pred)) + + x_range = (min_value, max_value) + y_range = (min_value, max_value) + + self.line_ = ax.plot( + [min_value, max_value], + [min_value, max_value], + label="Perfect predictions", + **line_kwargs, + )[0] + + x_data, y_data = self.y_pred, self.y_true + xlabel, ylabel = "Predicted values", "Actual values" + + self.scatter_ = ax.scatter( + x_data, y_data, label=scatter_label, **scatter_kwargs + ) + + # force to have a squared axis + ax.set_aspect("equal", adjustable="datalim") + ax.set_xticks(np.linspace(min_value, max_value, num=5)) + ax.set_yticks(np.linspace(min_value, max_value, num=5)) + else: # kind == "residual_vs_predicted" + x_range = (np.min(self.y_pred), np.max(self.y_pred)) + residuals = self.y_true - self.y_pred + y_range = (np.min(residuals), np.max(residuals)) + + self.line_ = ax.plot( + [np.min(self.y_pred), np.max(self.y_pred)], + [0, 0], + label="Perfect predictions", + **line_kwargs, + )[0] + + self.scatter_ = ax.scatter( + self.y_pred, residuals, label=scatter_label, **scatter_kwargs + ) + xlabel, ylabel = "Predicted values", "Residuals (actual - predicted)" + + ax.set(xlabel=xlabel, ylabel=ylabel) + ax.legend(title=estimator_name) + + self.ax_ = ax + self.figure_ = ax.figure + + if despine: + _despine_matplotlib_axis(self.ax_, x_range=x_range, y_range=y_range) + + @classmethod + def _from_predictions( + cls, + y_true, + y_pred, + *, + estimator, # currently only for consistency with other plots + estimator_name, + ml_task, # FIXME: to be used when having single-output vs. multi-output + data_source=None, + kind="residual_vs_predicted", + subsample=1_000, + random_state=None, + ax=None, + scatter_kwargs=None, + line_kwargs=None, + despine=True, + ): + """Plot the prediction error given the true and predicted targets. + + Parameters + ---------- + y_true : array-like of shape (n_samples,) + True target values. + + y_pred : array-like of shape (n_samples,) + Predicted target values. + + estimator : estimator instance + The estimator from which `y_pred` is obtained. + + estimator_name : str, + The name of the estimator. + + ml_task : {"binary-classification", "multiclass-classification"} + The machine learning task. + + data_source : {"train", "test", "X_y"}, default=None + The data source used to compute the ROC curve. + + kind : {"actual_vs_predicted", "residual_vs_predicted"}, \ + default="residual_vs_predicted" + The type of plot to draw: + + - "actual_vs_predicted" draws the observed values (y-axis) vs. + the predicted values (x-axis). + - "residual_vs_predicted" draws the residuals, i.e. difference + between observed and predicted values, (y-axis) vs. the predicted + values (x-axis). + + subsample : float, int or None, default=1_000 + Sampling the samples to be shown on the scatter plot. If `float`, + it should be between 0 and 1 and represents the proportion of the + original dataset. If `int`, it represents the number of samples + display on the scatter plot. If `None`, no subsampling will be + applied. by default, 1000 samples or less will be displayed. + + random_state : int or RandomState, default=None + Controls the randomness when `subsample` is not `None`. + See :term:`Glossary ` for details. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + scatter_kwargs : dict, default=None + Dictionary with keywords passed to the `matplotlib.pyplot.scatter` + call. + + line_kwargs : dict, default=None + Dictionary with keyword passed to the `matplotlib.pyplot.plot` + call to draw the optimal line. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + + Returns + ------- + display : PredictionErrorDisplay + Object that stores the computed values. + """ + random_state = check_random_state(random_state) + + n_samples = len(y_true) + if isinstance(subsample, numbers.Integral): + if subsample <= 0: + raise ValueError( + f"When an integer, subsample={subsample} should be positive." + ) + elif isinstance(subsample, numbers.Real): + if subsample <= 0 or subsample >= 1: + raise ValueError( + f"When a floating-point, subsample={subsample} should" + " be in the (0, 1) range." + ) + subsample = int(n_samples * subsample) + + if subsample is not None and subsample < n_samples: + indices = random_state.choice(np.arange(n_samples), size=subsample) + y_true = _safe_indexing(y_true, indices, axis=0) + y_pred = _safe_indexing(y_pred, indices, axis=0) + + viz = cls( + y_true=y_true, + y_pred=y_pred, + estimator_name=estimator_name, + data_source=data_source, + ) + + viz.plot( + ax=ax, + estimator_name=estimator_name, + kind=kind, + scatter_kwargs=scatter_kwargs, + line_kwargs=line_kwargs, + despine=despine, + ) + + return viz diff --git a/skore/src/skore/sklearn/_plot/roc_curve.py b/skore/src/skore/sklearn/_plot/roc_curve.py new file mode 100644 index 000000000..013d491e3 --- /dev/null +++ b/skore/src/skore/sklearn/_plot/roc_curve.py @@ -0,0 +1,399 @@ +from sklearn.metrics import auc, roc_curve +from sklearn.preprocessing import LabelBinarizer + +from skore.sklearn._plot.utils import ( + HelpDisplayMixin, + _ClassifierCurveDisplayMixin, + _despine_matplotlib_axis, + _validate_style_kwargs, +) + + +class RocCurveDisplay(HelpDisplayMixin, _ClassifierCurveDisplayMixin): + """ROC Curve visualization. + + An instance of this class is should created by `EstimatorReport.metrics.plot.roc()`. + You should not create an instance of this class directly. + + Parameters + ---------- + fpr : dict of list of ndarray + False positive rate. The structure is: + + - for binary classification: + - the key is the positive label. + - the value is a list of `ndarray`, each `ndarray` being the false + positive rate. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `ndarray`, each `ndarray` being the false + positive rate. + + tpr : dict of list of ndarray + True positive rate. The structure is: + + - for binary classification: + - the key is the positive label + - the value is a list of `ndarray`, each `ndarray` being the true + positive rate. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `ndarray`, each `ndarray` being the true + positive rate. + + roc_auc : dict of list of float + Area under the ROC curve. The structure is: + + - for binary classification: + - the key is the positive label + - the value is a list of `float`, each `float` being the area under + the ROC curve. + - for multiclass classification: + - the key is the class of interest in an OvR fashion. + - the value is a list of `float`, each `float` being the area under + the ROC curve. + + estimator_name : str + Name of the estimator. + + pos_label : str, default=None + The class considered as positive. Only meaningful for binary classification. + + data_source : {"train", "test", "X_y"}, default=None + The data source used to compute the ROC curve. + + Attributes + ---------- + ax_ : matplotlib axes + The axes on which the ROC curve is plotted. + + figure_ : matplotlib figure + The figure on which the ROC curve is plotted. + + lines_ : list of matplotlib lines + The lines of the ROC curve. + + chance_level_ : matplotlib line + The chance level line. + """ + + def __init__( + self, + *, + fpr, + tpr, + roc_auc, + estimator_name, + pos_label=None, + data_source=None, + ): + self.estimator_name = estimator_name + self.fpr = fpr + self.tpr = tpr + self.roc_auc = roc_auc + self.pos_label = pos_label + self.data_source = data_source + + def plot( + self, + ax=None, + *, + estimator_name=None, + roc_curve_kwargs=None, + plot_chance_level=True, + chance_level_kwargs=None, + despine=True, + ): + """Plot visualization. + + Extra keyword arguments will be passed to matplotlib's ``plot``. + + Parameters + ---------- + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + estimator_name : str, default=None + Name of the estimator used to plot the ROC curve. If `None`, we use + the inferred name from the estimator. + + roc_curve_kwargs : dict or list of dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the ROC curve(s). + + plot_chance_level : bool, default=True + Whether to plot the chance level. + + chance_level_kwargs : dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + + Returns + ------- + display : :class:`~sklearn.metrics.RocCurveDisplay` + Object that stores computed values. + """ + self.ax_, self.figure_, estimator_name = self._validate_plot_params( + ax=ax, estimator_name=estimator_name + ) + + self.lines_ = [] + if len(self.fpr) == 1: # binary-classification + if len(self.fpr[self.pos_label]) == 1: # single-split + if roc_curve_kwargs is None: + roc_curve_kwargs = {} + elif isinstance(roc_curve_kwargs, list): + if len(roc_curve_kwargs) > 1: + raise ValueError( + "You intend to plot a single ROC curve and provide " + "multiple ROC curve keyword arguments. Provide a single " + "dictionary or a list with a single dictionary." + ) + roc_curve_kwargs = roc_curve_kwargs[0] + + fpr = self.fpr[self.pos_label][0] + tpr = self.tpr[self.pos_label][0] + roc_auc = self.roc_auc[self.pos_label][0] + + default_line_kwargs = {} + if roc_auc is not None and self.data_source in ("train", "test"): + default_line_kwargs["label"] = ( + f"{self.data_source.title()} set (AUC = {roc_auc:0.2f})" + ) + elif roc_auc is not None: # data_source in (None, "X_y") + default_line_kwargs["label"] = f"AUC = {roc_auc:0.2f}" + + line_kwargs = _validate_style_kwargs( + default_line_kwargs, roc_curve_kwargs + ) + + (line_,) = self.ax_.plot(fpr, tpr, **line_kwargs) + self.lines_.append(line_) + else: # cross-validation + raise NotImplementedError( + "We don't support yet cross-validation" + ) # pragma: no cover + + info_pos_label = ( + f"\n(Positive label: {self.pos_label})" + if self.pos_label is not None + else "" + ) + else: # multiclass-classification + info_pos_label = None # irrelevant for multiclass + if roc_curve_kwargs is None: + roc_curve_kwargs = [{}] * len(self.fpr) + elif isinstance(roc_curve_kwargs, list): + if len(roc_curve_kwargs) != len(self.fpr): + raise ValueError( + "You intend to plot multiple ROC curves. We expect " + "`roc_curve_kwargs` to be a list of dictionaries with the " + "same length as the number of ROC curves. Got " + f"{len(roc_curve_kwargs)} instead of " + f"{len(self.fpr)}." + ) + else: + raise ValueError( + "You intend to plot multiple ROC curves. We expect " + "`roc_curve_kwargs` to be a list of dictionaries of " + f"{len(self.fpr)} elements. Got {roc_curve_kwargs!r} instead." + ) + + for class_idx, class_ in enumerate(self.fpr): + fpr_class = self.fpr[class_] + tpr_class = self.tpr[class_] + roc_auc_class = self.roc_auc[class_] + roc_curve_kwargs_class = roc_curve_kwargs[class_idx] + + if len(fpr_class) == 1: # single-split + fpr = fpr_class[0] + tpr = tpr_class[0] + roc_auc = roc_auc_class[0] + + default_line_kwargs = {} + if roc_auc is not None and self.data_source in ("train", "test"): + default_line_kwargs["label"] = ( + f"{str(class_).title()} - {self.data_source} " + f"set (AUC = {roc_auc:0.2f})" + ) + elif roc_auc is not None: # data_source in (None, "X_y") + default_line_kwargs["label"] = ( + f"{str(class_).title()} AUC = {roc_auc:0.2f}" + ) + + line_kwargs = _validate_style_kwargs( + default_line_kwargs, roc_curve_kwargs_class + ) + + (line_,) = self.ax_.plot(fpr, tpr, **line_kwargs) + self.lines_.append(line_) + else: # cross-validation + raise NotImplementedError( + "We don't support yet cross-validation" + ) # pragma: no cover + + default_chance_level_line_kw = { + "label": "Chance level (AUC = 0.5)", + "color": "k", + "linestyle": "--", + } + + if chance_level_kwargs is None: + chance_level_kwargs = {} + + chance_level_kwargs = _validate_style_kwargs( + default_chance_level_line_kw, chance_level_kwargs + ) + + xlabel = "False Positive Rate" + ylabel = "True Positive Rate" + if info_pos_label: + xlabel += info_pos_label + ylabel += info_pos_label + + self.ax_.set( + xlabel=xlabel, + xlim=(-0.01, 1.01), + ylabel=ylabel, + ylim=(-0.01, 1.01), + aspect="equal", + ) + + if plot_chance_level: + (self.chance_level_,) = self.ax_.plot((0, 1), (0, 1), **chance_level_kwargs) + else: + self.chance_level_ = None + + if despine: + _despine_matplotlib_axis(self.ax_) + + self.ax_.legend(loc="lower right", title=estimator_name) + + @classmethod + def _from_predictions( + cls, + y_true, + y_pred, + *, + estimator, + estimator_name, + ml_task, + data_source=None, + pos_label=None, + drop_intermediate=True, + ax=None, + roc_curve_kwargs=None, + plot_chance_level=True, + chance_level_kwargs=None, + despine=True, + ): + """Private method to create a RocCurveDisplay from predictions. + + Parameters + ---------- + y_true : array-like of shape (n_samples,) + True binary labels in binary classification. + + y_pred : array-like of shape (n_samples,) + Target scores, can either be probability estimates of the positive class, + confidence values, or non-thresholded measure of decisions (as returned by + “decision_function” on some classifiers). + + estimator : estimator instance + The estimator from which `y_pred` is obtained. + + estimator_name : str + Name of the estimator used to plot the ROC curve. + + ml_task : {"binary-classification", "multiclass-classification"} + The machine learning task. + + data_source : {"train", "test", "X_y"}, default=None + The data source used to compute the ROC curve. + + pos_label : int, float, bool or str, default=None + The class considered as the positive class when computing the + precision and recall metrics. + + drop_intermediate : bool, default=True + Whether to drop intermediate points with identical value. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + roc_curve_kwargs : dict or list of dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the ROC curve(s). + + plot_chance_level : bool, default=True + Whether to plot the chance level. + + chance_level_kwargs : dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + despine : bool, default=True + Whether to remove the top and right spines from the plot. + + Returns + ------- + display : RocCurveDisplay + Object that stores computed values. + """ + pos_label_validated = cls._validate_from_predictions_params( + y_true, y_pred, ml_task=ml_task, pos_label=pos_label + ) + + if ml_task == "binary-classification": + fpr, tpr, _ = roc_curve( + y_true, + y_pred, + pos_label=pos_label, + drop_intermediate=drop_intermediate, + ) + roc_auc = auc(fpr, tpr) + fpr = {pos_label_validated: [fpr]} + tpr = {pos_label_validated: [tpr]} + roc_auc = {pos_label_validated: [roc_auc]} + else: # multiclass-classification + # OvR fashion to collect fpr, tpr, and roc_auc + fpr, tpr, roc_auc = {}, {}, {} + label_binarizer = LabelBinarizer().fit(estimator.classes_) + y_true_onehot = label_binarizer.transform(y_true) + for class_idx, class_ in enumerate(estimator.classes_): + fpr_class, tpr_class, _ = roc_curve( + y_true_onehot[:, class_idx], + y_pred[:, class_idx], + pos_label=None, + drop_intermediate=drop_intermediate, + ) + roc_auc_class = auc(fpr_class, tpr_class) + + fpr[class_] = [fpr_class] + tpr[class_] = [tpr_class] + roc_auc[class_] = [roc_auc_class] + + viz = cls( + fpr=fpr, + tpr=tpr, + roc_auc=roc_auc, + estimator_name=estimator_name, + pos_label=pos_label_validated, + data_source=data_source, + ) + + viz.plot( + ax=ax, + estimator_name=estimator_name, + roc_curve_kwargs=roc_curve_kwargs, + plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kwargs, + despine=despine, + ) + + return viz diff --git a/skore/src/skore/sklearn/_plot/utils.py b/skore/src/skore/sklearn/_plot/utils.py new file mode 100644 index 000000000..14e24bedb --- /dev/null +++ b/skore/src/skore/sklearn/_plot/utils.py @@ -0,0 +1,200 @@ +import inspect +from io import StringIO + +import matplotlib.pyplot as plt +from rich.console import Console +from rich.panel import Panel +from rich.tree import Tree +from sklearn.utils.validation import ( + _check_pos_label_consistency, + check_consistent_length, +) + + +class HelpDisplayMixin: + """Mixin class to add help functionality to a class.""" + + def _get_attributes_for_help(self): + """Get the attributes ending with '_' to display in help.""" + attributes = [] + for name in dir(self): + if name.endswith("_") and not name.startswith("_"): + attributes.append(f".{name}") + return sorted(attributes) + + def _get_methods_for_help(self): + """Get the public methods to display in help.""" + methods = inspect.getmembers(self, predicate=inspect.ismethod) + filtered_methods = [] + for name, method in methods: + is_private = name.startswith("_") + is_class_method = inspect.ismethod(method) and method.__self__ is type(self) + is_help_method = name == "help" + if not (is_private or is_class_method or is_help_method): + filtered_methods.append((f".{name}(...)", method)) + return sorted(filtered_methods) + + def _create_help_tree(self): + """Create a rich Tree with attributes and methods.""" + tree = Tree("display") + + attributes = self._get_attributes_for_help() + attr_branch = tree.add("[bold cyan] Attributes[/bold cyan]") + # Ensure figure_ and ax_ are first + sorted_attrs = sorted(attributes) + sorted_attrs.remove(".ax_") + sorted_attrs.remove(".figure_") + sorted_attrs = [".figure_", ".ax_"] + [ + attr for attr in sorted_attrs if attr not in [".figure_", ".ax_"] + ] + for attr in sorted_attrs: + attr_branch.add(attr) + + methods = self._get_methods_for_help() + method_branch = tree.add("[bold cyan]Methods[/bold cyan]") + for name, method in methods: + description = ( + method.__doc__.split("\n")[0] + if method.__doc__ + else "No description available" + ) + method_branch.add(f"{name} - {description}") + + return tree + + def _create_help_panel(self): + return Panel( + self._create_help_tree(), + title=( + f"[bold cyan]📊 {self.__class__.__name__} for {self.estimator_name}" + "[/bold cyan]" + ), + border_style="orange1", + expand=False, + ) + + def help(self): + """Display available attributes and methods using rich.""" + from skore import console # avoid circular import + + console.print(self._create_help_panel()) + + def __repr__(self): + """Return a string representation using rich.""" + console = Console(file=StringIO(), force_terminal=False) + console.print(self._create_help_panel()) + return console.file.getvalue() + + +class _ClassifierCurveDisplayMixin: + """Mixin class to be used in Displays requiring a binary classifier. + + The aim of this class is to centralize some validations regarding the estimator and + the target and gather the response of the estimator. + """ + + def _validate_plot_params(self, *, ax, estimator_name): + if ax is None: + _, ax = plt.subplots() + + estimator_name = ( + self.estimator_name if estimator_name is None else estimator_name + ) + return ax, ax.figure, estimator_name + + @classmethod + def _validate_from_predictions_params( + cls, + y_true, + y_pred, + *, + ml_task, + sample_weight=None, + pos_label=None, + ): + check_consistent_length(y_true, y_pred, sample_weight) + + if ml_task == "binary-classification": + pos_label = _check_pos_label_consistency(pos_label, y_true) + + return pos_label + + +def _despine_matplotlib_axis(ax, *, x_range=(0, 1), y_range=(0, 1)): + """Despine the matplotlib axis. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The matplotlib axis to despine. + x_range : tuple of float, default=(0, 1) + The range of the x-axis. + y_range : tuple of float, default=(0, 1) + The range of the y-axis. + """ + for s in ["top", "right"]: + ax.spines[s].set_visible(False) + ax.spines["bottom"].set_bounds(x_range[0], x_range[1]) + ax.spines["left"].set_bounds(y_range[0], y_range[1]) + + +def _validate_style_kwargs(default_style_kwargs, user_style_kwargs): + """Create valid style kwargs by avoiding Matplotlib alias errors. + + Matplotlib raises an error when, for example, 'color' and 'c', or 'linestyle' and + 'ls', are specified together. To avoid this, we automatically keep only the one + specified by the user and raise an error if the user specifies both. + + Parameters + ---------- + default_style_kwargs : dict + The Matplotlib style kwargs used by default in the scikit-learn display. + user_style_kwargs : dict + The user-defined Matplotlib style kwargs. + + Returns + ------- + valid_style_kwargs : dict + The validated style kwargs taking into account both default and user-defined + Matplotlib style kwargs. + """ + invalid_to_valid_kw = { + "ls": "linestyle", + "c": "color", + "ec": "edgecolor", + "fc": "facecolor", + "lw": "linewidth", + "mec": "markeredgecolor", + "mfcalt": "markerfacecoloralt", + "ms": "markersize", + "mew": "markeredgewidth", + "mfc": "markerfacecolor", + "aa": "antialiased", + "ds": "drawstyle", + "font": "fontproperties", + "family": "fontfamily", + "name": "fontname", + "size": "fontsize", + "stretch": "fontstretch", + "style": "fontstyle", + "variant": "fontvariant", + "weight": "fontweight", + "ha": "horizontalalignment", + "va": "verticalalignment", + "ma": "multialignment", + } + for invalid_key, valid_key in invalid_to_valid_kw.items(): + if invalid_key in user_style_kwargs and valid_key in user_style_kwargs: + raise TypeError( + f"Got both {invalid_key} and {valid_key}, which are aliases of one " + "another" + ) + valid_style_kwargs = default_style_kwargs.copy() + + for key in user_style_kwargs: + if key in invalid_to_valid_kw: + valid_style_kwargs[invalid_to_valid_kw[key]] = user_style_kwargs[key] + else: + valid_style_kwargs[key] = user_style_kwargs[key] + + return valid_style_kwargs diff --git a/skore/src/skore/utils/_accessor.py b/skore/src/skore/utils/_accessor.py new file mode 100644 index 000000000..aafcc8d92 --- /dev/null +++ b/skore/src/skore/utils/_accessor.py @@ -0,0 +1,15 @@ +def _check_supported_ml_task(supported_ml_tasks): + def check(accessor): + supported_task = any( + task in accessor._parent._ml_task for task in supported_ml_tasks + ) + + if not supported_task: + raise AttributeError( + f"The {accessor._parent._ml_task} task is not a supported task by " + f"function called. The supported tasks are {supported_ml_tasks}." + ) + + return True + + return check diff --git a/skore/tests/conftest.py b/skore/tests/conftest.py index 6b2e3e0ad..82f8a2b48 100644 --- a/skore/tests/conftest.py +++ b/skore/tests/conftest.py @@ -7,6 +7,13 @@ from skore.view.view_repository import ViewRepository +def pytest_configure(config): + # Use matplotlib agg backend during the tests including doctests + import matplotlib + + matplotlib.use("agg") + + @pytest.fixture def mock_now(): return datetime.now(tz=timezone.utc) @@ -37,3 +44,21 @@ def in_memory_project(): item_repository=item_repository, view_repository=view_repository, ) + + +@pytest.fixture(scope="function") +def pyplot(): + """Setup and teardown fixture for matplotlib. + + This fixture closes the figures before and after running the functions. + + Returns + ------- + pyplot : module + The ``matplotlib.pyplot`` module. + """ + from matplotlib import pyplot + + pyplot.close("all") + yield pyplot + pyplot.close("all") diff --git a/skore/tests/unit/sklearn/plot/test_common.py b/skore/tests/unit/sklearn/plot/test_common.py new file mode 100644 index 000000000..8384f1538 --- /dev/null +++ b/skore/tests/unit/sklearn/plot/test_common.py @@ -0,0 +1,80 @@ +import pytest +from sklearn.datasets import make_classification, make_regression +from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.model_selection import train_test_split +from skore import EstimatorReport + + +@pytest.mark.parametrize( + "plot_func, estimator, dataset", + [ + ("roc", LogisticRegression(), make_classification(random_state=42)), + ( + "precision_recall", + LogisticRegression(), + make_classification(random_state=42), + ), + ("prediction_error", LinearRegression(), make_regression(random_state=42)), + ], +) +def test_display_help(pyplot, capsys, plot_func, estimator, dataset): + """Check that the help method writes to the console.""" + + X_train, X_test, y_train, y_test = train_test_split(*dataset, random_state=42) + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = getattr(report.metrics.plot, plot_func)() + + display.help() + captured = capsys.readouterr() + assert f"📊 {display.__class__.__name__}" in captured.out + + +@pytest.mark.parametrize( + "plot_func, estimator, dataset", + [ + ("roc", LogisticRegression(), make_classification(random_state=42)), + ( + "precision_recall", + LogisticRegression(), + make_classification(random_state=42), + ), + ("prediction_error", LinearRegression(), make_regression(random_state=42)), + ], +) +def test_display_repr(pyplot, plot_func, estimator, dataset): + """Check that __repr__ returns a string starting with the expected prefix.""" + X_train, X_test, y_train, y_test = train_test_split(*dataset, random_state=42) + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = getattr(report.metrics.plot, plot_func)() + + repr_str = repr(display) + assert f"📊 {display.__class__.__name__}" in repr_str + + +@pytest.mark.parametrize( + "plot_func, estimator, dataset", + [ + ("roc", LogisticRegression(), make_classification(random_state=42)), + ( + "precision_recall", + LogisticRegression(), + make_classification(random_state=42), + ), + ("prediction_error", LinearRegression(), make_regression(random_state=42)), + ], +) +def test_display_provide_ax(pyplot, plot_func, estimator, dataset): + """Check that we can provide an ax to the plot method.""" + X_train, X_test, y_train, y_test = train_test_split(*dataset, random_state=42) + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = getattr(report.metrics.plot, plot_func)() + + _, ax = pyplot.subplots() + display.plot(ax=ax) + assert display.ax_ is ax diff --git a/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py b/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py new file mode 100644 index 000000000..ec42899f6 --- /dev/null +++ b/skore/tests/unit/sklearn/plot/test_precision_recall_curve.py @@ -0,0 +1,238 @@ +import matplotlib as mpl +import pytest +from sklearn.datasets import make_classification +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split +from skore import EstimatorReport +from skore.sklearn._plot import PrecisionRecallCurveDisplay + + +@pytest.fixture +def binary_classification_data(): + X, y = make_classification(random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + return LogisticRegression().fit(X_train, y_train), X_train, X_test, y_train, y_test + + +@pytest.fixture +def multiclass_classification_data(): + X, y = make_classification(n_classes=3, n_clusters_per_class=1, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + return LogisticRegression().fit(X_train, y_train), X_train, X_test, y_train, y_test + + +def test_precision_recall_curve_display_binary_classification( + pyplot, binary_classification_data +): + """Check the attributes and default plotting behaviour of the + precision-recall curve plot with binary data. + """ + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.precision_recall() + assert isinstance(display, PrecisionRecallCurveDisplay) + + # check the structure of the attributes + for attr_name in ("precision", "recall", "average_precision", "prevalence"): + assert isinstance(getattr(display, attr_name), dict) + assert len(getattr(display, attr_name)) == 1 + + attr = getattr(display, attr_name) + assert list(attr.keys()) == [estimator.classes_[1]] + assert list(attr.keys()) == [display.pos_label] + assert isinstance(attr[estimator.classes_[1]], list) + assert len(attr[estimator.classes_[1]]) == 1 + + assert isinstance(display.lines_, list) + assert len(display.lines_) == 1 + precision_recall_curve_mpl = display.lines_[0] + assert isinstance(precision_recall_curve_mpl, mpl.lines.Line2D) + assert ( + precision_recall_curve_mpl.get_label() + == f"Test set (AP = {display.average_precision[estimator.classes_[1]][0]:0.2f})" + ) + assert precision_recall_curve_mpl.get_color() == "#1f77b4" # tab:blue in hex + + assert display.chance_levels_ is None + + assert isinstance(display.ax_, mpl.axes.Axes) + legend = display.ax_.get_legend() + assert legend.get_title().get_text() == estimator.__class__.__name__ + assert len(legend.get_texts()) == 1 + + assert display.ax_.get_xlabel() == "Recall\n(Positive label: 1)" + assert display.ax_.get_ylabel() == "Precision\n(Positive label: 1)" + assert display.ax_.get_adjustable() == "box" + assert display.ax_.get_aspect() in ("equal", 1.0) + assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) + + +def test_precision_recall_curve_display_data_source(pyplot, binary_classification_data): + """Check that we can pass the `data_source` argument to the precision-recall + curve plot. + """ + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.precision_recall(data_source="train") + assert display.lines_[0].get_label() == "Train set (AP = 1.00)" + + display = report.metrics.plot.precision_recall( + data_source="X_y", X=X_train, y=y_train + ) + assert display.lines_[0].get_label() == "AP = 1.00" + + +def test_precision_recall_curve_display_multiclass_classification( + pyplot, multiclass_classification_data +): + """Check the attributes and default plotting behaviour of the precision-recall + curve plot with multiclass data. + """ + estimator, X_train, X_test, y_train, y_test = multiclass_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.precision_recall() + assert isinstance(display, PrecisionRecallCurveDisplay) + + # check the structure of the attributes + for attr_name in ("precision", "recall", "average_precision", "prevalence"): + assert isinstance(getattr(display, attr_name), dict) + assert len(getattr(display, attr_name)) == len(estimator.classes_) + + attr = getattr(display, attr_name) + for class_label in estimator.classes_: + assert isinstance(attr[class_label], list) + assert len(attr[class_label]) == 1 + + assert isinstance(display.lines_, list) + assert len(display.lines_) == len(estimator.classes_) + default_colors = ["#1f77b4", "#ff7f0e", "#2ca02c"] + for class_label, expected_color in zip(estimator.classes_, default_colors): + precision_recall_curve_mpl = display.lines_[class_label] + assert isinstance(precision_recall_curve_mpl, mpl.lines.Line2D) + assert precision_recall_curve_mpl.get_label() == ( + f"{str(class_label).title()} - test set " + f"(AP = {display.average_precision[class_label][0]:0.2f})" + ) + assert precision_recall_curve_mpl.get_color() == expected_color + + assert display.chance_levels_ is None + + assert isinstance(display.ax_, mpl.axes.Axes) + legend = display.ax_.get_legend() + assert legend.get_title().get_text() == estimator.__class__.__name__ + assert len(legend.get_texts()) == 3 + + assert display.ax_.get_xlabel() == "Recall" + assert display.ax_.get_ylabel() == "Precision" + assert display.ax_.get_adjustable() == "box" + assert display.ax_.get_aspect() in ("equal", 1.0) + assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) + + +def test_precision_recall_curve_display_pr_curve_kwargs( + pyplot, binary_classification_data, multiclass_classification_data +): + """Check that we can pass keyword arguments to the precision-recall curve plot.""" + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.precision_recall() + for pr_curve_kwargs in ({"color": "red"}, [{"color": "red"}]): + display.plot( + pr_curve_kwargs=pr_curve_kwargs, + plot_chance_level=True, + chance_level_kwargs={"color": "blue"}, + ) + + assert display.lines_[0].get_color() == "red" + assert display.chance_levels_[0].get_color() == "blue" + + display.plot(plot_chance_level=True) + assert display.chance_levels_[0].get_color() == "k" + + display.plot(plot_chance_level=True, chance_level_kwargs=[{"color": "red"}]) + assert display.chance_levels_[0].get_color() == "red" + + estimator, X_train, X_test, y_train, y_test = multiclass_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.precision_recall() + display.plot( + pr_curve_kwargs=[dict(color="red"), dict(color="blue"), dict(color="green")], + plot_chance_level=True, + chance_level_kwargs=[ + dict(color="red"), + dict(color="blue"), + dict(color="green"), + ], + ) + assert display.lines_[0].get_color() == "red" + assert display.lines_[1].get_color() == "blue" + assert display.lines_[2].get_color() == "green" + assert display.chance_levels_[0].get_color() == "red" + assert display.chance_levels_[1].get_color() == "blue" + assert display.chance_levels_[2].get_color() == "green" + + display.plot(plot_chance_level=True) + for chance_level in display.chance_levels_: + assert chance_level.get_color() == "k" + + display.plot(despine=False) + assert display.ax_.spines["top"].get_visible() + assert display.ax_.spines["right"].get_visible() + + +def test_precision_recall_curve_display_plot_error_wrong_pr_curve_kwargs( + pyplot, binary_classification_data, multiclass_classification_data +): + """Check that we raise a proper error message when passing an inappropriate + value for the `roc_curve_kwargs` argument. + """ + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.precision_recall() + err_msg = ( + "You intend to plot a single precision-recall curve and provide multiple " + "precision-recall curve keyword arguments" + ) + with pytest.raises(ValueError, match=err_msg): + display.plot(pr_curve_kwargs=[{}, {}]) + + err_msg = ( + "You intend to plot a single chance level line and provide multiple chance " + "level line keyword arguments" + ) + with pytest.raises(ValueError, match=err_msg): + display.plot(plot_chance_level=True, chance_level_kwargs=[{}, {}]) + + estimator, X_train, X_test, y_train, y_test = multiclass_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.precision_recall() + err_msg = "You intend to plot multiple precision-recall curves." + with pytest.raises(ValueError, match=err_msg): + display.plot(pr_curve_kwargs=[{}, {}]) + + with pytest.raises(ValueError, match=err_msg): + display.plot(pr_curve_kwargs={}) + + err_msg = ( + "You intend to plot multiple precision-recall curves. We expect " + "`chance_level_kwargs` to be a list" + ) + with pytest.raises(ValueError, match=err_msg): + display.plot(plot_chance_level=True, chance_level_kwargs=[{}, {}]) + + with pytest.raises(ValueError, match=err_msg): + display.plot(plot_chance_level=True, chance_level_kwargs={}) diff --git a/skore/tests/unit/sklearn/plot/test_prediction_error.py b/skore/tests/unit/sklearn/plot/test_prediction_error.py new file mode 100644 index 000000000..30db39c00 --- /dev/null +++ b/skore/tests/unit/sklearn/plot/test_prediction_error.py @@ -0,0 +1,136 @@ +import matplotlib as mpl +import numpy as np +import pytest +from sklearn.datasets import make_regression +from sklearn.linear_model import LinearRegression +from sklearn.model_selection import train_test_split +from skore import EstimatorReport +from skore.sklearn._plot import PredictionErrorDisplay + + +@pytest.fixture +def regression_data(): + X, y = make_regression(random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + return LinearRegression().fit(X_train, y_train), X_train, X_test, y_train, y_test + + +@pytest.mark.parametrize( + "params, err_msg", + [ + ({"subsample": -1}, "When an integer, subsample=-1 should be"), + ({"subsample": 20.0}, "When a floating-point, subsample=20.0 should be"), + ({"subsample": -20.0}, "When a floating-point, subsample=-20.0 should be"), + ({"kind": "xxx"}, "`kind` must be one of"), + ], +) +def test_prediction_error_display_raise_error(pyplot, params, err_msg, regression_data): + """Check that we raise the proper error when making the parameters + validation.""" + estimator, X_train, X_test, y_train, y_test = regression_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + with pytest.raises(ValueError, match=err_msg): + report.metrics.plot.prediction_error(**params) + + +def test_prediction_error_display_regression(pyplot, regression_data): + """Check the attributes and default plotting behaviour of the prediction error plot + with regression data.""" + estimator, X_train, X_test, y_train, y_test = regression_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.prediction_error() + assert isinstance(display, PredictionErrorDisplay) + + # check the structure of the attributes + assert isinstance(display.y_true, np.ndarray) + assert isinstance(display.y_pred, np.ndarray) + np.testing.assert_allclose(display.y_true, y_test) + np.testing.assert_allclose(display.y_pred, estimator.predict(X_test)) + assert display.data_source == "test" + + assert isinstance(display.line_, mpl.lines.Line2D) + assert display.line_.get_label() == "Perfect predictions" + assert display.line_.get_color() == "black" + + assert isinstance(display.scatter_, mpl.collections.PathCollection) + + assert isinstance(display.ax_, mpl.axes.Axes) + legend = display.ax_.get_legend() + assert legend.get_title().get_text() == estimator.__class__.__name__ + assert len(legend.get_texts()) == 2 + + assert display.ax_.get_xlabel() == "Predicted values" + assert display.ax_.get_ylabel() == "Residuals (actual - predicted)" + + +def test_prediction_error_display_regression_kind(pyplot, regression_data): + """Check the attributes when switching to the "actual_vs_predicted" kind.""" + estimator, X_train, X_test, y_train, y_test = regression_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.prediction_error(kind="actual_vs_predicted") + assert isinstance(display, PredictionErrorDisplay) + + assert isinstance(display.line_, mpl.lines.Line2D) + assert display.line_.get_label() == "Perfect predictions" + assert display.line_.get_color() == "black" + + assert isinstance(display.scatter_, mpl.collections.PathCollection) + + assert isinstance(display.ax_, mpl.axes.Axes) + legend = display.ax_.get_legend() + assert legend.get_title().get_text() == estimator.__class__.__name__ + assert len(legend.get_texts()) == 2 + + assert display.ax_.get_xlabel() == "Predicted values" + assert display.ax_.get_ylabel() == "Actual values" + + assert display.ax_.get_xlim() == display.ax_.get_ylim() + assert display.ax_.get_aspect() in ("equal", 1.0) + + +def test_prediction_error_display_data_source(pyplot, regression_data): + """Check that we can pass the `data_source` argument to the prediction error + plot.""" + estimator, X_train, X_test, y_train, y_test = regression_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.prediction_error(data_source="train") + assert display.line_.get_label() == "Perfect predictions" + assert display.scatter_.get_label() == "Train set" + + display = report.metrics.plot.prediction_error( + data_source="X_y", X=X_train, y=y_train + ) + assert display.line_.get_label() == "Perfect predictions" + assert display.scatter_.get_label() == "Data set" + + +def test_prediction_error_display_kwargs(pyplot, regression_data): + """Check that we can pass keyword arguments to the prediction error plot.""" + estimator, X_train, X_test, y_train, y_test = regression_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.prediction_error() + display.plot(scatter_kwargs={"color": "red"}, line_kwargs={"color": "blue"}) + np.testing.assert_allclose(display.scatter_.get_facecolor(), [[1, 0, 0, 0.8]]) + assert display.line_.get_color() == "blue" + + display.plot(despine=False) + assert display.ax_.spines["top"].get_visible() + assert display.ax_.spines["right"].get_visible() + + expected_subsample = 10 + display = report.metrics.plot.prediction_error(subsample=expected_subsample) + assert len(display.scatter_.get_offsets()) == expected_subsample + + expected_subsample = int(X_test.shape[0] * 0.5) + display = report.metrics.plot.prediction_error(subsample=0.5) + assert len(display.scatter_.get_offsets()) == expected_subsample diff --git a/skore/tests/unit/sklearn/plot/test_roc_curve.py b/skore/tests/unit/sklearn/plot/test_roc_curve.py new file mode 100644 index 000000000..9f2f09678 --- /dev/null +++ b/skore/tests/unit/sklearn/plot/test_roc_curve.py @@ -0,0 +1,199 @@ +import matplotlib as mpl +import pytest +from sklearn.datasets import make_classification +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split +from skore import EstimatorReport +from skore.sklearn._plot import RocCurveDisplay + + +@pytest.fixture +def binary_classification_data(): + X, y = make_classification(random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + return LogisticRegression().fit(X_train, y_train), X_train, X_test, y_train, y_test + + +@pytest.fixture +def multiclass_classification_data(): + X, y = make_classification(n_classes=3, n_clusters_per_class=1, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + return LogisticRegression().fit(X_train, y_train), X_train, X_test, y_train, y_test + + +def test_roc_curve_display_binary_classification(pyplot, binary_classification_data): + """Check the attributes and default plotting behaviour of the ROC curve plot with + binary data.""" + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.roc() + assert isinstance(display, RocCurveDisplay) + + # check the structure of the attributes + for attr_name in ("fpr", "tpr", "roc_auc"): + assert isinstance(getattr(display, attr_name), dict) + assert len(getattr(display, attr_name)) == 1 + + attr = getattr(display, attr_name) + assert list(attr.keys()) == [estimator.classes_[1]] + assert list(attr.keys()) == [display.pos_label] + assert isinstance(attr[estimator.classes_[1]], list) + assert len(attr[estimator.classes_[1]]) == 1 + + assert isinstance(display.lines_, list) + assert len(display.lines_) == 1 + roc_curve_mpl = display.lines_[0] + assert isinstance(roc_curve_mpl, mpl.lines.Line2D) + assert ( + roc_curve_mpl.get_label() + == f"Test set (AUC = {display.roc_auc[estimator.classes_[1]][0]:0.2f})" + ) + assert roc_curve_mpl.get_color() == "#1f77b4" # tab:blue in hex + + assert isinstance(display.chance_level_, mpl.lines.Line2D) + assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)" + assert display.chance_level_.get_color() == "k" + + assert isinstance(display.ax_, mpl.axes.Axes) + legend = display.ax_.get_legend() + assert legend.get_title().get_text() == estimator.__class__.__name__ + assert len(legend.get_texts()) == 2 + + assert display.ax_.get_xlabel() == "False Positive Rate\n(Positive label: 1)" + assert display.ax_.get_ylabel() == "True Positive Rate\n(Positive label: 1)" + assert display.ax_.get_adjustable() == "box" + assert display.ax_.get_aspect() in ("equal", 1.0) + assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) + + +def test_roc_curve_display_multiclass_classification( + pyplot, multiclass_classification_data +): + """Check the attributes and default plotting behaviour of the ROC curve plot with + multiclass data.""" + estimator, X_train, X_test, y_train, y_test = multiclass_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.roc() + assert isinstance(display, RocCurveDisplay) + + # check the structure of the attributes + for attr_name in ("fpr", "tpr", "roc_auc"): + assert isinstance(getattr(display, attr_name), dict) + assert len(getattr(display, attr_name)) == len(estimator.classes_) + + attr = getattr(display, attr_name) + for class_label in estimator.classes_: + assert isinstance(attr[class_label], list) + assert len(attr[class_label]) == 1 + + assert isinstance(display.lines_, list) + assert len(display.lines_) == len(estimator.classes_) + default_colors = ["#1f77b4", "#ff7f0e", "#2ca02c"] + for class_label, expected_color in zip(estimator.classes_, default_colors): + roc_curve_mpl = display.lines_[class_label] + assert isinstance(roc_curve_mpl, mpl.lines.Line2D) + assert roc_curve_mpl.get_label() == ( + f"{str(class_label).title()} - test set " + f"(AUC = {display.roc_auc[class_label][0]:0.2f})" + ) + assert roc_curve_mpl.get_color() == expected_color + + assert isinstance(display.chance_level_, mpl.lines.Line2D) + assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)" + assert display.chance_level_.get_color() == "k" + + assert isinstance(display.ax_, mpl.axes.Axes) + legend = display.ax_.get_legend() + assert legend.get_title().get_text() == estimator.__class__.__name__ + assert len(legend.get_texts()) == 4 + + assert display.ax_.get_xlabel() == "False Positive Rate" + assert display.ax_.get_ylabel() == "True Positive Rate" + assert display.ax_.get_adjustable() == "box" + assert display.ax_.get_aspect() in ("equal", 1.0) + assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) + + +def test_roc_curve_display_data_source(pyplot, binary_classification_data): + """Check that we can pass the `data_source` argument to the ROC curve plot.""" + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.roc(data_source="train") + assert display.lines_[0].get_label() == "Train set (AUC = 1.00)" + + display = report.metrics.plot.roc(data_source="X_y", X=X_train, y=y_train) + assert display.lines_[0].get_label() == "AUC = 1.00" + + +def test_roc_curve_display_plot_error_wrong_roc_curve_kwargs( + pyplot, binary_classification_data, multiclass_classification_data +): + """Check that we raise a proper error message when passing an inappropriate + value for the `roc_curve_kwargs` argument.""" + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.roc() + err_msg = ( + "You intend to plot a single ROC curve and provide multiple ROC curve " + "keyword arguments" + ) + with pytest.raises(ValueError, match=err_msg): + display.plot(roc_curve_kwargs=[{}, {}]) + + estimator, X_train, X_test, y_train, y_test = multiclass_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.roc() + err_msg = "You intend to plot multiple ROC curves." + with pytest.raises(ValueError, match=err_msg): + display.plot(roc_curve_kwargs=[{}, {}]) + + with pytest.raises(ValueError, match=err_msg): + display.plot(roc_curve_kwargs={}) + + +def test_roc_curve_display_roc_curve_kwargs( + pyplot, binary_classification_data, multiclass_classification_data +): + """Check that we can pass keyword arguments to the ROC curve plot.""" + estimator, X_train, X_test, y_train, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.roc() + display.plot( + roc_curve_kwargs={"color": "red"}, chance_level_kwargs={"color": "blue"} + ) + + assert display.lines_[0].get_color() == "red" + assert display.chance_level_.get_color() == "blue" + + estimator, X_train, X_test, y_train, y_test = multiclass_classification_data + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + display = report.metrics.plot.roc() + display.plot( + roc_curve_kwargs=[dict(color="red"), dict(color="blue"), dict(color="green")], + chance_level_kwargs={"color": "blue"}, + ) + assert display.lines_[0].get_color() == "red" + assert display.lines_[1].get_color() == "blue" + assert display.lines_[2].get_color() == "green" + assert display.chance_level_.get_color() == "blue" + + display.plot(plot_chance_level=False) + assert display.chance_level_ is None + + display.plot(despine=False) + assert display.ax_.spines["top"].get_visible() + assert display.ax_.spines["right"].get_visible() diff --git a/skore/tests/unit/sklearn/plot/test_utils.py b/skore/tests/unit/sklearn/plot/test_utils.py new file mode 100644 index 000000000..fb54322bb --- /dev/null +++ b/skore/tests/unit/sklearn/plot/test_utils.py @@ -0,0 +1,65 @@ +import pytest +from skore.sklearn._plot.utils import _validate_style_kwargs + + +@pytest.mark.parametrize( + "default_kwargs, user_kwargs, expected", + [ + ( + {"color": "blue", "linewidth": 2}, + {"linestyle": "dashed"}, + {"color": "blue", "linewidth": 2, "linestyle": "dashed"}, + ), + ( + {"color": "blue", "linestyle": "solid"}, + {"c": "red", "ls": "dashed"}, + {"color": "red", "linestyle": "dashed"}, + ), + ( + {"label": "xxx", "color": "k", "linestyle": "--"}, + {"ls": "-."}, + {"label": "xxx", "color": "k", "linestyle": "-."}, + ), + ({}, {}, {}), + ( + {}, + { + "ls": "dashed", + "c": "red", + "ec": "black", + "fc": "yellow", + "lw": 2, + "mec": "green", + "mfcalt": "blue", + "ms": 5, + }, + { + "linestyle": "dashed", + "color": "red", + "edgecolor": "black", + "facecolor": "yellow", + "linewidth": 2, + "markeredgecolor": "green", + "markerfacecoloralt": "blue", + "markersize": 5, + }, + ), + ], +) +def test_validate_style_kwargs(default_kwargs, user_kwargs, expected): + """Check the behaviour of `validate_style_kwargs` with various type of entries.""" + result = _validate_style_kwargs(default_kwargs, user_kwargs) + assert result == expected, ( + "The validation of style keywords does not provide the expected results: " + f"Got {result} instead of {expected}." + ) + + +@pytest.mark.parametrize( + "default_kwargs, user_kwargs", + [({}, {"ls": 2, "linestyle": 3}), ({}, {"c": "r", "color": "blue"})], +) +def test_validate_style_kwargs_error(default_kwargs, user_kwargs): + """Check that `validate_style_kwargs` raises TypeError""" + with pytest.raises(TypeError): + _validate_style_kwargs(default_kwargs, user_kwargs) diff --git a/skore/tests/unit/sklearn/test_estimator.py b/skore/tests/unit/sklearn/test_estimator.py new file mode 100644 index 000000000..8fd2c54aa --- /dev/null +++ b/skore/tests/unit/sklearn/test_estimator.py @@ -0,0 +1,927 @@ +import re +from copy import deepcopy + +import joblib +import numpy as np +import pandas as pd +import pytest +from sklearn.base import clone +from sklearn.cluster import KMeans +from sklearn.datasets import make_classification, make_regression +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import LinearRegression, LogisticRegression +from sklearn.metrics import make_scorer, median_absolute_error, r2_score, rand_score +from sklearn.model_selection import train_test_split +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.svm import SVC +from sklearn.utils.validation import check_is_fitted +from skore import EstimatorReport +from skore.sklearn._estimator.utils import _check_supported_estimator +from skore.sklearn._plot import RocCurveDisplay + + +@pytest.fixture +def binary_classification_data(): + """Create a binary classification dataset and return fitted estimator and data.""" + X, y = make_classification(random_state=42) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 + ) + return RandomForestClassifier().fit(X_train, y_train), X_test, y_test + + +@pytest.fixture +def binary_classification_data_svc(): + """Create a binary classification dataset and return fitted estimator and data. + The estimator is a SVC that does not support `predict_proba`. + """ + X, y = make_classification(random_state=42) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 + ) + return SVC().fit(X_train, y_train), X_test, y_test + + +@pytest.fixture +def multiclass_classification_data(): + """Create a multiclass classification dataset and return fitted estimator and + data.""" + X, y = make_classification( + n_classes=3, n_clusters_per_class=1, random_state=42, n_informative=10 + ) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 + ) + return RandomForestClassifier().fit(X_train, y_train), X_test, y_test + + +@pytest.fixture +def multiclass_classification_data_svc(): + """Create a multiclass classification dataset and return fitted estimator and + data. The estimator is a SVC that does not support `predict_proba`. + """ + X, y = make_classification( + n_classes=3, n_clusters_per_class=1, random_state=42, n_informative=10 + ) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 + ) + return SVC().fit(X_train, y_train), X_test, y_test + + +@pytest.fixture +def binary_classification_data_pipeline(): + """Create a binary classification dataset and return fitted pipeline and data.""" + X, y = make_classification(random_state=42) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 + ) + estimator = Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression())]) + return estimator.fit(X_train, y_train), X_test, y_test + + +@pytest.fixture +def regression_data(): + """Create a regression dataset and return fitted estimator and data.""" + X, y = make_regression(random_state=42) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 + ) + return LinearRegression().fit(X_train, y_train), X_test, y_test + + +@pytest.fixture +def regression_multioutput_data(): + """Create a regression dataset and return fitted estimator and data.""" + X, y = make_regression(n_targets=2, random_state=42) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, random_state=42 + ) + return LinearRegression().fit(X_train, y_train), X_test, y_test + + +def test_check_supported_estimator(): + """Test the behaviour of `_check_supported_estimator`.""" + + class MockParent: + def __init__(self, estimator): + self.estimator = estimator + + class MockAccessor: + def __init__(self, parent): + self._parent = parent + + parent = MockParent(LogisticRegression()) + accessor = MockAccessor(parent) + check = _check_supported_estimator((LogisticRegression,)) + assert check(accessor) + + pipeline = Pipeline([("clf", LogisticRegression())]) + parent = MockParent(pipeline) + accessor = MockAccessor(parent) + assert check(accessor) + + parent = MockParent(RandomForestClassifier()) + accessor = MockAccessor(parent) + err_msg = ( + "The RandomForestClassifier estimator is not supported by the function called." + ) + with pytest.raises(AttributeError, match=err_msg): + check(accessor) + + +######################################################################################## +# Check the general behaviour of the report +######################################################################################## + + +@pytest.mark.parametrize("fit", [True, "auto"]) +def test_estimator_not_fitted(fit): + """Test that an error is raised when trying to create a report from an unfitted + estimator and no data are provided to fit the estimator. + """ + estimator = LinearRegression() + err_msg = "The training data is required to fit the estimator. " + with pytest.raises(ValueError, match=err_msg): + EstimatorReport(estimator, fit=fit) + + +@pytest.mark.parametrize("fit", [True, "auto"]) +def test_estimator_report_from_unfitted_estimator(fit): + """Check the general behaviour of passing an unfitted estimator and training + data.""" + X, y = make_regression(random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + estimator = LinearRegression() + report = EstimatorReport( + estimator, + fit=fit, + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + ) + + check_is_fitted(report.estimator) + assert report.estimator is not estimator # the estimator should be cloned + + assert report.X_train is X_train + assert report.y_train is y_train + assert report.X_test is X_test + assert report.y_test is y_test + + err_msg = "attribute is immutable" + with pytest.raises(AttributeError, match=err_msg): + report.estimator = LinearRegression() + with pytest.raises(AttributeError, match=err_msg): + report.X_train = X_train + with pytest.raises(AttributeError, match=err_msg): + report.y_train = y_train + + +@pytest.mark.parametrize("fit", [False, "auto"]) +def test_estimator_report_from_fitted_estimator(binary_classification_data, fit): + """Check the general behaviour of passing an already fitted estimator without + refitting it.""" + estimator, X, y = binary_classification_data + report = EstimatorReport(estimator, fit=fit, X_test=X, y_test=y) + + assert report.estimator is estimator # we should not clone the estimator + assert report.X_train is None + assert report.y_train is None + assert report.X_test is X + assert report.y_test is y + + err_msg = "attribute is immutable" + with pytest.raises(AttributeError, match=err_msg): + report.estimator = LinearRegression() + with pytest.raises(AttributeError, match=err_msg): + report.X_train = X + with pytest.raises(AttributeError, match=err_msg): + report.y_train = y + + +def test_estimator_report_from_fitted_pipeline(binary_classification_data_pipeline): + """Check the general behaviour of passing an already fitted pipeline without + refitting it. + """ + estimator, X, y = binary_classification_data_pipeline + report = EstimatorReport(estimator, X_test=X, y_test=y) + + assert report.estimator is estimator # we should not clone the estimator + assert report.estimator_name == estimator[-1].__class__.__name__ + assert report.X_train is None + assert report.y_train is None + assert report.X_test is X + assert report.y_test is y + + +def test_estimator_report_invalidate_cache_data(binary_classification_data): + """Check that we invalidate the cache when the data is changed.""" + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + for attribute in ("X_test", "y_test"): + report._cache["mocking"] = "mocking" # mock writing to cache + setattr(report, attribute, None) + assert report._cache == {} + + +@pytest.mark.parametrize( + "Estimator, X_test, y_test, supported_plot_methods, not_supported_plot_methods", + [ + ( + RandomForestClassifier(), + *make_classification(random_state=42), + ["roc", "precision_recall"], + ["prediction_error"], + ), + ( + RandomForestClassifier(), + *make_classification(n_classes=3, n_clusters_per_class=1, random_state=42), + ["roc", "precision_recall"], + ["prediction_error"], + ), + ( + LinearRegression(), + *make_regression(random_state=42), + ["prediction_error"], + ["roc", "precision_recall"], + ), + ], +) +def test_estimator_report_check_support_plot( + Estimator, X_test, y_test, supported_plot_methods, not_supported_plot_methods +): + """Check that the available plot methods are correctly registered.""" + estimator = Estimator.fit(X_test, y_test) + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + for supported_plot_method in supported_plot_methods: + assert hasattr(report.metrics.plot, supported_plot_method) + + for not_supported_plot_method in not_supported_plot_methods: + assert not hasattr(report.metrics.plot, not_supported_plot_method) + + +def test_estimator_report_help(capsys, binary_classification_data): + """Check that the help method writes to the console.""" + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + report.help() + captured = capsys.readouterr() + assert ( + f"📓 Tools to diagnose estimator {estimator.__class__.__name__}" in captured.out + ) + + +def test_estimator_report_repr(binary_classification_data): + """Check that __repr__ returns a string starting with the expected prefix.""" + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + repr_str = repr(report) + assert f"📓 Tools to diagnose estimator {estimator.__class__.__name__}" in repr_str + + +@pytest.mark.parametrize( + "fixture_name", ["binary_classification_data", "regression_data"] +) +def test_estimator_report_cache_predictions(request, fixture_name): + """Check that calling cache_predictions fills the cache.""" + estimator, X_test, y_test = request.getfixturevalue(fixture_name) + report = EstimatorReport( + estimator, X_train=X_test, y_train=y_test, X_test=X_test, y_test=y_test + ) + + assert report._cache == {} + report.cache_predictions() + assert report._cache != {} + stored_cache = deepcopy(report._cache) + report.cache_predictions() + # check that the keys are exactly the same + assert report._cache.keys() == stored_cache.keys() + + +######################################################################################## +# Check the plot methods +######################################################################################## + + +def test_estimator_report_plot_help(capsys, binary_classification_data): + """Check that the help method writes to the console.""" + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + report.metrics.plot.help() + captured = capsys.readouterr() + assert "🎨 Available plot methods" in captured.out + + +def test_estimator_report_plot_repr(binary_classification_data): + """Check that __repr__ returns a string starting with the expected prefix.""" + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + repr_str = repr(report.metrics.plot) + assert "🎨 Available plot methods" in repr_str + + +def test_estimator_report_plot_roc(binary_classification_data): + """Check that the ROC plot method works.""" + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + assert isinstance(report.metrics.plot.roc(), RocCurveDisplay) + + +@pytest.mark.parametrize("display", ["roc", "precision_recall"]) +def test_estimator_report_display_binary_classification( + pyplot, binary_classification_data, display +): + """General behaviour of the function creating display on binary classification.""" + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + assert hasattr(report.metrics.plot, display) + display_first_call = getattr(report.metrics.plot, display)() + assert report._cache != {} + display_second_call = getattr(report.metrics.plot, display)() + assert display_first_call is display_second_call + + +@pytest.mark.parametrize("display", ["prediction_error"]) +def test_estimator_report_display_regression(pyplot, regression_data, display): + """General behaviour of the function creating display on regression.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + assert hasattr(report.metrics.plot, display) + display_first_call = getattr(report.metrics.plot, display)() + assert report._cache != {} + display_second_call = getattr(report.metrics.plot, display)() + assert display_first_call is display_second_call + + +@pytest.mark.parametrize("display", ["roc", "precision_recall"]) +def test_estimator_report_display_binary_classification_external_data( + pyplot, binary_classification_data, display +): + """General behaviour of the function creating display on binary classification + when passing external data. + """ + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator) + assert hasattr(report.metrics.plot, display) + display_first_call = getattr(report.metrics.plot, display)( + data_source="X_y", X=X_test, y=y_test + ) + assert report._cache != {} + display_second_call = getattr(report.metrics.plot, display)( + data_source="X_y", X=X_test, y=y_test + ) + assert display_first_call is display_second_call + + +@pytest.mark.parametrize("display", ["prediction_error"]) +def test_estimator_report_display_regression_external_data( + pyplot, regression_data, display +): + """General behaviour of the function creating display on regression when passing + external data. + """ + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator) + assert hasattr(report.metrics.plot, display) + display_first_call = getattr(report.metrics.plot, display)( + data_source="X_y", X=X_test, y=y_test + ) + assert report._cache != {} + display_second_call = getattr(report.metrics.plot, display)( + data_source="X_y", X=X_test, y=y_test + ) + assert display_first_call is display_second_call + + +@pytest.mark.parametrize("display", ["roc", "precision_recall"]) +def test_estimator_report_display_binary_classification_switching_data_source( + pyplot, binary_classification_data, display +): + """Check that we don't hit the cache when switching the data source.""" + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport( + estimator, X_train=X_test, y_train=y_test, X_test=X_test, y_test=y_test + ) + assert hasattr(report.metrics.plot, display) + display_first_call = getattr(report.metrics.plot, display)(data_source="test") + assert report._cache != {} + display_second_call = getattr(report.metrics.plot, display)(data_source="train") + assert display_first_call is not display_second_call + display_third_call = getattr(report.metrics.plot, display)( + data_source="X_y", X=X_test, y=y_test + ) + assert display_first_call is not display_third_call + assert display_second_call is not display_third_call + + +@pytest.mark.parametrize("display", ["prediction_error"]) +def test_estimator_report_display_regression_switching_data_source( + pyplot, regression_data, display +): + """Check that we don't hit the cache when switching the data source.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport( + estimator, X_train=X_test, y_train=y_test, X_test=X_test, y_test=y_test + ) + assert hasattr(report.metrics.plot, display) + display_first_call = getattr(report.metrics.plot, display)(data_source="test") + assert report._cache != {} + display_second_call = getattr(report.metrics.plot, display)(data_source="train") + assert display_first_call is not display_second_call + display_third_call = getattr(report.metrics.plot, display)( + data_source="X_y", X=X_test, y=y_test + ) + assert display_first_call is not display_third_call + assert display_second_call is not display_third_call + + +######################################################################################## +# Check the metrics methods +######################################################################################## + + +def test_estimator_report_metrics_help(capsys, binary_classification_data): + """Check that the help method writes to the console.""" + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + report.metrics.help() + captured = capsys.readouterr() + assert "📏 Available metrics methods" in captured.out + + +def test_estimator_report_metrics_repr(binary_classification_data): + """Check that __repr__ returns a string starting with the expected prefix.""" + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + repr_str = repr(report.metrics) + assert "📏 Available metrics methods" in repr_str + + +@pytest.mark.parametrize( + "metric", ["accuracy", "precision", "recall", "brier_score", "roc_auc", "log_loss"] +) +def test_estimator_report_metrics_binary_classification( + binary_classification_data, metric +): + """Check the behaviour of the metrics methods available for binary + classification. + """ + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + assert hasattr(report.metrics, metric) + result = getattr(report.metrics, metric)() + assert isinstance(result, pd.DataFrame) + # check that we hit the cache + result_with_cache = getattr(report.metrics, metric)() + pd.testing.assert_frame_equal(result, result_with_cache) + + # check that something was written to the cache + assert report._cache != {} + report.clean_cache() + + # check that passing using data outside from the report works and that we they + # don't come from the cache + result_external_data = getattr(report.metrics, metric)( + data_source="X_y", X=X_test, y=y_test + ) + assert isinstance(result_external_data, pd.DataFrame) + pd.testing.assert_frame_equal(result, result_external_data) + assert report._cache != {} + + +@pytest.mark.parametrize("metric", ["r2", "rmse"]) +def test_estimator_report_metrics_regression(regression_data, metric): + """Check the behaviour of the metrics methods available for regression.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + assert hasattr(report.metrics, metric) + result = getattr(report.metrics, metric)() + assert isinstance(result, pd.DataFrame) + # check that we hit the cache + result_with_cache = getattr(report.metrics, metric)() + pd.testing.assert_frame_equal(result, result_with_cache) + + # check that something was written to the cache + assert report._cache != {} + report.clean_cache() + + # check that passing using data outside from the report works and that we they + # don't come from the cache + result_external_data = getattr(report.metrics, metric)( + data_source="X_y", X=X_test, y=y_test + ) + assert isinstance(result_external_data, pd.DataFrame) + pd.testing.assert_frame_equal(result, result_external_data) + assert report._cache != {} + + +def _normalize_metric_name(column): + """Helper to normalize the metric name present in a pandas column that could be + a multi-index or single-index.""" + # if we have a multi-index, then the metric name is on level 0 + s = column[0] if isinstance(column, tuple) else column + # Remove spaces and underscores + return re.sub(r"[^a-zA-Z]", "", s.lower()) + + +def _check_results_report_metrics(result, expected_metrics, expected_nb_stats): + assert isinstance(result, pd.DataFrame) + assert len(result.columns) == expected_nb_stats + + normalized_expected = { + _normalize_metric_name(metric) for metric in expected_metrics + } + for column in result.columns: + normalized_column = _normalize_metric_name(column) + matches = [ + metric for metric in normalized_expected if metric == normalized_column + ] + assert len(matches) == 1, ( + f"No match found for column '{column}' in expected metrics: " + f" {expected_metrics}" + ) + + +@pytest.mark.parametrize("pos_label, nb_stats", [(None, 2), (1, 1)]) +def test_estimator_report_report_metrics_binary( + binary_classification_data, binary_classification_data_svc, pos_label, nb_stats +): + """Check the behaviour of the `report_metrics` method with binary + classification. We test both with an SVC that does not support `predict_proba` and a + RandomForestClassifier that does. + """ + estimator, X_test, y_test = binary_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + result = report.metrics.report_metrics(pos_label=pos_label) + expected_metrics = ("precision", "recall", "roc_auc", "brier_score") + # depending on `pos_label`, we report a stats for each class or not for + # precision and recall + expected_nb_stats = 2 * nb_stats + 2 + _check_results_report_metrics(result, expected_metrics, expected_nb_stats) + + # Repeat the same experiment where we the target labels are not [0, 1] but + # ["neg", "pos"]. We check that we don't get any error. + target_names = np.array(["neg", "pos"], dtype=object) + pos_label_name = target_names[pos_label] if pos_label is not None else pos_label + y_test = target_names[y_test] + estimator = clone(estimator).fit(X_test, y_test) + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + result = report.metrics.report_metrics(pos_label=pos_label_name) + expected_metrics = ("precision", "recall", "roc_auc", "brier_score") + # depending on `pos_label`, we report a stats for each class or not for + # precision and recall + expected_nb_stats = 2 * nb_stats + 2 + _check_results_report_metrics(result, expected_metrics, expected_nb_stats) + + estimator, X_test, y_test = binary_classification_data_svc + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + result = report.metrics.report_metrics(pos_label=pos_label) + expected_metrics = ("precision", "recall", "roc_auc") + # depending on `pos_label`, we report a stats for each class or not for + # precision and recall + expected_nb_stats = 2 * nb_stats + 1 + _check_results_report_metrics(result, expected_metrics, expected_nb_stats) + + +def test_estimator_report_report_metrics_multiclass( + multiclass_classification_data, multiclass_classification_data_svc +): + """Check the behaviour of the `report_metrics` method with multiclass + classification. + """ + estimator, X_test, y_test = multiclass_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + result = report.metrics.report_metrics() + expected_metrics = ("precision", "recall", "roc_auc", "log_loss") + # since we are not averaging by default, we report 3 statistics for + # precision, recall and roc_auc + expected_nb_stats = 3 * 3 + 1 + _check_results_report_metrics(result, expected_metrics, expected_nb_stats) + + estimator, X_test, y_test = multiclass_classification_data_svc + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + result = report.metrics.report_metrics() + expected_metrics = ("precision", "recall") + # since we are not averaging by default, we report 3 statistics for + # precision and recall + expected_nb_stats = 3 * 2 + _check_results_report_metrics(result, expected_metrics, expected_nb_stats) + + +def test_estimator_report_report_metrics_regression(regression_data): + """Check the behaviour of the `report_metrics` method with regression.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + result = report.metrics.report_metrics() + expected_metrics = ("r2", "rmse") + _check_results_report_metrics(result, expected_metrics, len(expected_metrics)) + + +def test_estimator_report_report_metrics_scoring_kwargs( + regression_multioutput_data, multiclass_classification_data +): + """Check the behaviour of the `report_metrics` method with scoring kwargs.""" + estimator, X_test, y_test = regression_multioutput_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + assert hasattr(report.metrics, "report_metrics") + result = report.metrics.report_metrics(scoring_kwargs={"multioutput": "raw_values"}) + assert result.shape == (1, 4) + assert isinstance(result.columns, pd.MultiIndex) + assert result.columns.names == ["Metric", "Output"] + + estimator, X_test, y_test = multiclass_classification_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + assert hasattr(report.metrics, "report_metrics") + result = report.metrics.report_metrics(scoring_kwargs={"average": None}) + assert result.shape == (1, 10) + assert isinstance(result.columns, pd.MultiIndex) + assert result.columns.names == ["Metric", "Class label"] + + +def test_estimator_report_interaction_cache_metrics(regression_multioutput_data): + """Check that the cache take into account the 'kwargs' of a metric.""" + estimator, X_test, y_test = regression_multioutput_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + # The underlying metrics will call `_compute_metric_scores` that take some arbitrary + # kwargs apart from `pos_label`. Let's pass an arbitrary kwarg and make sure it is + # part of the cache. + multioutput = "raw_values" + result_r2_raw_values = report.metrics.r2(multioutput=multioutput) + should_raise = True + for cached_key in report._cache: + if any(item == multioutput for item in cached_key): + should_raise = False + break + assert ( + not should_raise + ), f"The value {multioutput} should be stored in one of the cache keys" + assert result_r2_raw_values.shape == (1, 2) + + multioutput = "uniform_average" + result_r2_uniform_average = report.metrics.r2(multioutput=multioutput) + should_raise = True + for cached_key in report._cache: + if any(item == multioutput for item in cached_key): + should_raise = False + break + assert ( + not should_raise + ), f"The value {multioutput} should be stored in one of the cache keys" + assert result_r2_uniform_average.shape == (1, 1) + + +def test_estimator_report_custom_metric(regression_data): + """Check the behaviour of the `custom_metric` computation in the report.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + def custom_metric(y_true, y_pred, threshold=0.5): + residuals = y_true - y_pred + return np.mean(np.where(residuals < threshold, residuals, 1)) + + threshold = 1 + result = report.metrics.custom_metric( + metric_function=custom_metric, + metric_name="Custom Metric", + response_method="predict", + threshold=threshold, + ) + should_raise = True + for cached_key in report._cache: + if any(item == threshold for item in cached_key): + should_raise = False + break + assert ( + not should_raise + ), f"The value {threshold} should be stored in one of the cache keys" + + assert result.columns.tolist() == ["Custom Metric"] + assert result.to_numpy()[0, 0] == pytest.approx( + custom_metric(y_test, estimator.predict(X_test), threshold) + ) + + threshold = 100 + result = report.metrics.custom_metric( + metric_function=custom_metric, + metric_name="Custom Metric", + response_method="predict", + threshold=threshold, + ) + should_raise = True + for cached_key in report._cache: + if any(item == threshold for item in cached_key): + should_raise = False + break + assert ( + not should_raise + ), f"The value {threshold} should be stored in one of the cache keys" + + assert result.columns.tolist() == ["Custom Metric"] + assert result.to_numpy()[0, 0] == pytest.approx( + custom_metric(y_test, estimator.predict(X_test), threshold) + ) + + +def test_estimator_report_custom_function_kwargs_numpy_array(regression_data): + """Check that we are able to store a hash of a numpy array in the cache when they + are passed as kwargs. + """ + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + weights = np.ones_like(y_test) * 2 + hash_weights = joblib.hash(weights) + + def custom_metric(y_true, y_pred, some_weights): + return np.mean((y_true - y_pred) * some_weights) + + result = report.metrics.custom_metric( + metric_function=custom_metric, + metric_name="Custom Metric", + response_method="predict", + some_weights=weights, + ) + should_raise = True + for cached_key in report._cache: + if any(item == hash_weights for item in cached_key): + should_raise = False + break + assert ( + not should_raise + ), "The hash of the weights should be stored in one of the cache keys" + + assert result.columns.tolist() == ["Custom Metric"] + assert result.to_numpy()[0, 0] == pytest.approx( + custom_metric(y_test, estimator.predict(X_test), weights) + ) + + +def test_estimator_report_report_metrics_with_custom_metric(regression_data): + """Check that we can pass a custom metric with specific kwargs into + `report_metrics`.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + weights = np.ones_like(y_test) * 2 + + def custom_metric(y_true, y_pred, some_weights): + return np.mean((y_true - y_pred) * some_weights) + + result = report.metrics.report_metrics( + scoring=["r2", custom_metric], + scoring_kwargs={"some_weights": weights, "response_method": "predict"}, + ) + assert result.shape == (1, 2) + np.testing.assert_allclose( + result.to_numpy(), + [ + [ + r2_score(y_test, estimator.predict(X_test)), + custom_metric(y_test, estimator.predict(X_test), weights), + ] + ], + ) + + +def test_estimator_report_report_metrics_with_scorer(regression_data): + """Check that we can pass scikit-learn scorer with different parameters to + the `report_metrics` method.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + weights = np.ones_like(y_test) * 2 + + def custom_metric(y_true, y_pred, some_weights): + return np.mean((y_true - y_pred) * some_weights) + + median_absolute_error_scorer = make_scorer( + median_absolute_error, response_method="predict" + ) + custom_metric_scorer = make_scorer( + custom_metric, response_method="predict", some_weights=weights + ) + result = report.metrics.report_metrics( + scoring=[r2_score, median_absolute_error_scorer, custom_metric_scorer], + scoring_kwargs={"response_method": "predict"}, # only dispatched to r2_score + ) + assert result.shape == (1, 3) + np.testing.assert_allclose( + result.to_numpy(), + [ + [ + r2_score(y_test, estimator.predict(X_test)), + median_absolute_error(y_test, estimator.predict(X_test)), + custom_metric(y_test, estimator.predict(X_test), weights), + ] + ], + ) + + +def test_estimator_report_report_metrics_invalid_metric_type(regression_data): + """Check that we raise the expected error message if an invalid metric is passed.""" + estimator, X_test, y_test = regression_data + report = EstimatorReport(estimator, X_test=X_test, y_test=y_test) + + err_msg = re.escape("Invalid type of metric: for 1") + with pytest.raises(ValueError, match=err_msg): + report.metrics.report_metrics(scoring=[1]) + + +def test_estimator_report_get_X_y_and_data_source_hash_error(): + """Check that we raise the proper error in `get_X_y_and_use_cache`.""" + X, y = make_classification(n_classes=2, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + + estimator = LogisticRegression().fit(X_train, y_train) + report = EstimatorReport(estimator) + + err_msg = re.escape( + "Invalid data source: unknown. Possible values are: " "test, train, X_y." + ) + with pytest.raises(ValueError, match=err_msg): + report.metrics.log_loss(data_source="unknown") + + for data_source in ("train", "test"): + err_msg = re.escape( + f"No {data_source} data (i.e. X_{data_source} and y_{data_source}) were " + f"provided when creating the reporter. Please provide the {data_source} " + "data either when creating the reporter or by setting data_source to " + "'X_y' and providing X and y." + ) + with pytest.raises(ValueError, match=err_msg): + report.metrics.log_loss(data_source=data_source) + + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + + for data_source in ("train", "test"): + err_msg = f"X and y must be None when data_source is {data_source}." + with pytest.raises(ValueError, match=err_msg): + report.metrics.log_loss(data_source=data_source, X=X_test, y=y_test) + + err_msg = "X and y must be provided." + with pytest.raises(ValueError, match=err_msg): + report.metrics.log_loss(data_source="X_y") + + # FIXME: once we choose some basic metrics for clustering, then we don't need to + # use `custom_metric` for them. + estimator = KMeans().fit(X_train) + report = EstimatorReport(estimator, X_test=X_test) + err_msg = "X must be provided." + with pytest.raises(ValueError, match=err_msg): + report.metrics.custom_metric( + rand_score, response_method="predict", data_source="X_y" + ) + + report = EstimatorReport(estimator) + for data_source in ("train", "test"): + err_msg = re.escape( + f"No {data_source} data (i.e. X_{data_source}) were provided when " + f"creating the reporter. Please provide the {data_source} data either " + f"when creating the reporter or by setting data_source to 'X_y' and " + f"providing X and y." + ) + with pytest.raises(ValueError, match=err_msg): + report.metrics.custom_metric( + rand_score, response_method="predict", data_source=data_source + ) + + +@pytest.mark.parametrize("data_source", ("train", "test", "X_y")) +def test_estimator_report_get_X_y_and_data_source_hash(data_source): + """Check the general behaviour of `get_X_y_and_use_cache`.""" + X, y = make_classification(n_classes=2, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + + estimator = LogisticRegression() + report = EstimatorReport( + estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test + ) + + kwargs = {"X": X_test, "y": y_test} if data_source == "X_y" else {} + X, y, data_source_hash = report.metrics._get_X_y_and_data_source_hash( + data_source=data_source, **kwargs + ) + + if data_source == "train": + assert X is X_train + assert y is y_train + assert data_source_hash is None + elif data_source == "test": + assert X is X_test + assert y is y_test + assert data_source_hash is None + elif data_source == "X_y": + assert X is X_test + assert y is y_test + assert data_source_hash == joblib.hash((X_test, y_test)) diff --git a/skore/tests/unit/utils/test_accessors.py b/skore/tests/unit/utils/test_accessors.py new file mode 100644 index 000000000..7e7a8953b --- /dev/null +++ b/skore/tests/unit/utils/test_accessors.py @@ -0,0 +1,60 @@ +import pytest +from skore.externals._pandas_accessors import DirNamesMixin, _register_accessor +from skore.utils._accessor import _check_supported_ml_task + + +def test_register_accessor(): + """Test that an accessor is properly registered and accessible on a class + instance. + """ + + class ParentClass(DirNamesMixin): + pass + + def register_parent_class_accessor(name: str): + """Register an accessor for the ParentClass class.""" + return _register_accessor(name, ParentClass) + + @register_parent_class_accessor("accessor") + class _Accessor: + def __init__(self, parent): + self._parent = parent + + def func(self): + return True + + obj = ParentClass() + assert hasattr(obj, "accessor") + assert isinstance(obj.accessor, _Accessor) + assert obj.accessor.func() + + +def test_check_supported_ml_task(): + """Test that ML task validation accepts supported tasks and rejects unsupported + ones. + """ + + class MockParent: + def __init__(self, ml_task): + self._ml_task = ml_task + + class MockAccessor: + def __init__(self, parent): + self._parent = parent + + parent = MockParent("binary-classification") + accessor = MockAccessor(parent) + check = _check_supported_ml_task( + ["binary-classification", "multiclass-classification"] + ) + assert check(accessor) + + parent = MockParent("multiclass-classification") + accessor = MockAccessor(parent) + assert check(accessor) + + parent = MockParent("regression") + accessor = MockAccessor(parent) + err_msg = "The regression task is not a supported task by function called." + with pytest.raises(AttributeError, match=err_msg): + check(accessor) diff --git a/sphinx/_templates/autosummary/accessor.rst b/sphinx/_templates/autosummary/accessor.rst new file mode 100644 index 000000000..145ca83dd --- /dev/null +++ b/sphinx/_templates/autosummary/accessor.rst @@ -0,0 +1,5 @@ +{{ objname | escape | underline(line="=") }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessor:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/sphinx/_templates/autosummary/accessor_attribute.rst b/sphinx/_templates/autosummary/accessor_attribute.rst new file mode 100644 index 000000000..c2769d66d --- /dev/null +++ b/sphinx/_templates/autosummary/accessor_attribute.rst @@ -0,0 +1,5 @@ +{{ objname | escape | underline(line="=") }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessorattribute:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/sphinx/_templates/autosummary/accessor_callable.rst b/sphinx/_templates/autosummary/accessor_callable.rst new file mode 100644 index 000000000..261adfdf1 --- /dev/null +++ b/sphinx/_templates/autosummary/accessor_callable.rst @@ -0,0 +1,5 @@ +{{ objname | escape | underline(line="=") }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessorcallable:: {{ (module.split('.')[1:] + [objname]) | join('.') }}.__call__ diff --git a/sphinx/_templates/autosummary/accessor_method.rst b/sphinx/_templates/autosummary/accessor_method.rst new file mode 100644 index 000000000..5c116571d --- /dev/null +++ b/sphinx/_templates/autosummary/accessor_method.rst @@ -0,0 +1,5 @@ +{{ objname | escape | underline(line="=") }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessormethod:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/sphinx/api.rst b/sphinx/api.rst index 4c3b2535c..e1a519531 100644 --- a/sphinx/api.rst +++ b/sphinx/api.rst @@ -36,3 +36,59 @@ These functions and classes enhance scikit-learn's ones. train_test_split CrossValidationReporter item.cross_validation_item.CrossValidationItem + +Report for a single estimator +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The class :class:`EstimatorReport` provides a reporter allowing to inspect and +evaluate a scikit-learn estimator in an interactive way. The functionalities of the +reporter are accessible through accessors. + +.. autosummary:: + :toctree: generated/ + :template: base.rst + :caption: Reporting for a single estimator + + EstimatorReport + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: autosummary/accessor_method.rst + + EstimatorReport.help + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: autosummary/accessor.rst + + EstimatorReport.metrics + +Metrics +""""""" + +The `metrics` accessor helps you to evaluate the statistical performance of your +estimator. In addition, we provide a sub-accessor `plot`, to get the common +performance metric representations. + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: autosummary/accessor_method.rst + + EstimatorReport.metrics.help + EstimatorReport.metrics.report_metrics + EstimatorReport.metrics.custom_metric + EstimatorReport.metrics.accuracy + EstimatorReport.metrics.brier_score + EstimatorReport.metrics.log_loss + EstimatorReport.metrics.precision + EstimatorReport.metrics.r2 + EstimatorReport.metrics.recall + EstimatorReport.metrics.rmse + EstimatorReport.metrics.roc_auc + EstimatorReport.metrics.plot.help + EstimatorReport.metrics.plot.precision_recall + EstimatorReport.metrics.plot.prediction_error + EstimatorReport.metrics.plot.roc diff --git a/sphinx/conf.py b/sphinx/conf.py index b22954ba3..7d554848e 100644 --- a/sphinx/conf.py +++ b/sphinx/conf.py @@ -7,6 +7,7 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information import os +import sphinx_autosummary_accessors from sphinx_gallery.sorting import ExplicitOrder project = "skore" @@ -27,11 +28,16 @@ "sphinx_gallery.gen_gallery", "sphinx_copybutton", "sphinx_tabs.tabs", + "sphinx_autosummary_accessors", ] -templates_path = ["_templates"] exclude_patterns = ["build", "Thumbs.db", ".DS_Store"] +# The reST default role (used for this markup: `text`) to use for all +# documents. +default_role = "literal" + # Add any paths that contain templates here, relative to this directory. +autosummary_generate = True # generate stubs for all classes templates_path = ["_templates"] # -- Options for HTML output ------------------------------------------------- @@ -107,7 +113,7 @@ # Use :html_theme.sidebar_secondary.remove: for file-wide removal "secondary_sidebar_items": { "**": ["page-toc", "sourcelink", "sg_download_links", "sg_launcher_links"], - "index": [], # hide secondary sidebar items for the landing page + "index": [], # hide secondary sidebar items for the landing page "install": [], }, "external_links": [