From e6d2a69f9bb3a522af524d4e11a034f0a20f3a9e Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre <guillaume@probabl.ai>
Date: Wed, 8 Jan 2025 15:25:59 +0100
Subject: [PATCH] split the estimator stubs file into per-module stubs

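Move the monolithic skore/sklearn/_estimator/estimator.pyi out of the
way and give each runtime module its own stub:

- __init__.pyi re-exports the public EstimatorReport.
- base.pyi declares the shared _HelpMixin and _BaseAccessor.
- metrics_accessor.pyi declares _MetricsAccessor and _PlotMetricsAccessor.
- report.pyi declares EstimatorReport.

While splitting, the stubs are brought in line with the implementation:
typing.Callable replaces the bare `callable`, precision/recall now expose
scikit-learn's `average` literals with a None default, and private
helpers such as _compute_metric_scores and _get_display are declared.

The public import path is unchanged. A minimal sketch of the typed
surface, based only on the stub signatures (illustrative, not part of
the patch; assumes skore and scikit-learn are installed):

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    from skore.sklearn._estimator import EstimatorReport

    X, y = make_classification(random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # fit="auto" (the stub's default) fits the estimator on X_train/y_train
    report = EstimatorReport(
        LogisticRegression(),
        X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,
    )
    report.metrics.accuracy()  # -> pd.DataFrame, per metrics_accessor.pyi
    report.metrics.plot.roc()  # -> sklearn.metrics.RocCurveDisplay
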
---
 .../src/skore/sklearn/_estimator/__init__.pyi |   3 +
 skore/src/skore/sklearn/_estimator/base.pyi   |  33 ++++
 .../skore/sklearn/_estimator/estimator.pyi    | 174 ------------------
 .../sklearn/_estimator/metrics_accessor.pyi   | 168 +++++++++++++++++
 skore/src/skore/sklearn/_estimator/report.pyi |  71 +++++++
 5 files changed, 275 insertions(+), 174 deletions(-)
 create mode 100644 skore/src/skore/sklearn/_estimator/__init__.pyi
 create mode 100644 skore/src/skore/sklearn/_estimator/base.pyi
 delete mode 100644 skore/src/skore/sklearn/_estimator/estimator.pyi
 create mode 100644 skore/src/skore/sklearn/_estimator/metrics_accessor.pyi
 create mode 100644 skore/src/skore/sklearn/_estimator/report.pyi
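
(Note for reviewers, not part of the commit message: one way to check
that the split stubs still line up with the runtime modules is mypy's
stubtest, run in an environment where skore is installed. The invocation
below is a suggestion, not something this patch adds:

    python -m mypy.stubtest skore.sklearn._estimator

Anything between the "---" separator and the first "diff --git" line,
including this note, is ignored by `git am`.)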

diff --git a/skore/src/skore/sklearn/_estimator/__init__.pyi b/skore/src/skore/sklearn/_estimator/__init__.pyi
new file mode 100644
index 000000000..f496ff1f3
--- /dev/null
+++ b/skore/src/skore/sklearn/_estimator/__init__.pyi
@@ -0,0 +1,3 @@
+from skore.sklearn._estimator.report import EstimatorReport
+
+__all__ = ["EstimatorReport"]
diff --git a/skore/src/skore/sklearn/_estimator/base.pyi b/skore/src/skore/sklearn/_estimator/base.pyi
new file mode 100644
index 000000000..eca62e85e
--- /dev/null
+++ b/skore/src/skore/sklearn/_estimator/base.pyi
@@ -0,0 +1,33 @@
+from typing import Any, Literal, Optional
+
+import numpy as np
+from rich.panel import Panel
+from rich.tree import Tree
+
+class _HelpMixin:
+    def _get_methods_for_help(self) -> list[tuple[str, Any]]: ...
+    def _sort_methods_for_help(
+        self, methods: list[tuple[str, Any]]
+    ) -> list[tuple[str, Any]]: ...
+    def _format_method_name(self, name: str) -> str: ...
+    def _get_method_description(self, method: Any) -> str: ...
+    def _create_help_panel(self) -> Panel: ...
+    def _get_help_panel_title(self) -> str: ...
+    def _create_help_tree(self) -> Tree: ...
+    def help(self) -> None: ...
+    def __repr__(self) -> str: ...
+
+class _BaseAccessor(_HelpMixin):
+    _parent: Any
+    _icon: str
+
+    def __init__(self, parent: Any, icon: str) -> None: ...
+    def _get_help_panel_title(self) -> str: ...
+    def _create_help_tree(self) -> Tree: ...
+    def _get_X_y_and_data_source_hash(
+        self,
+        *,
+        data_source: Literal["test", "train", "X_y"],
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+    ) -> tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[str]]: ...
diff --git a/skore/src/skore/sklearn/_estimator/estimator.pyi b/skore/src/skore/sklearn/_estimator/estimator.pyi
deleted file mode 100644
index a7da48927..000000000
--- a/skore/src/skore/sklearn/_estimator/estimator.pyi
+++ /dev/null
@@ -1,174 +0,0 @@
-from typing import Any, Literal, Optional, Union
-
-import matplotlib.axes
-import pandas as pd
-from numpy import ndarray
-from sklearn.base import BaseEstimator
-from sklearn.metrics import PrecisionRecallDisplay, RocCurveDisplay
-
-from skore.sklearn._plot import PredictionErrorDisplay
-
-class _BaseAccessor:
-    _parent: EstimatorReport
-    def __init__(self, parent: EstimatorReport, icon: str) -> None: ...
-    def help(self) -> None: ...
-
-class _PlotMetricsAccessor(_BaseAccessor):
-    def roc(
-        self,
-        *,
-        data_source: Literal["test", "train", "X_y"] = "test",
-        X: Optional[ndarray] = None,
-        y: Optional[ndarray] = None,
-        pos_label: Optional[Union[str, int]] = None,
-        ax: Optional[matplotlib.axes.Axes] = None,
-        name: Optional[str] = None,
-    ) -> RocCurveDisplay: ...
-    def precision_recall(
-        self,
-        *,
-        data_source: Literal["test", "train", "X_y"] = "test",
-        X: Optional[ndarray] = None,
-        y: Optional[ndarray] = None,
-        pos_label: Optional[Union[str, int]] = None,
-        ax: Optional[matplotlib.axes.Axes] = None,
-        name: Optional[str] = None,
-    ) -> PrecisionRecallDisplay: ...
-    def prediction_error(
-        self,
-        *,
-        data_source: Literal["test", "train", "X_y"] = "test",
-        X: Optional[ndarray] = None,
-        y: Optional[ndarray] = None,
-        ax: Optional[matplotlib.axes.Axes] = None,
-        kind: Literal[
-            "actual_vs_predicted", "residual_vs_predicted"
-        ] = "residual_vs_predicted",
-        subsample: Optional[Union[int, float]] = 1_000,
-    ) -> PredictionErrorDisplay: ...
-
-class _MetricsAccessor(_BaseAccessor):
-    plot: _PlotMetricsAccessor
-
-    def report_metrics(
-        self,
-        *,
-        data_source: Literal["test", "train", "X_y"] = "test",
-        X: Optional[ndarray] = None,
-        y: Optional[ndarray] = None,
-        scoring: Optional[Union[list[str], callable]] = None,
-        pos_label: int = 1,
-        scoring_kwargs: Optional[dict] = None,
-    ) -> pd.DataFrame: ...
-    def custom_metric(
-        self,
-        metric_function: callable,
-        response_method: Union[str, list[str]],
-        *,
-        metric_name: Optional[str] = None,
-        data_source: Literal["test", "train", "X_y"] = "test",
-        X: Optional[ndarray] = None,
-        y: Optional[ndarray] = None,
-        **kwargs: Any,
-    ) -> pd.DataFrame: ...
-    def accuracy(
-        self,
-        *,
-        data_source: Literal["test", "train", "X_y"] = "test",
-        X: Optional[ndarray] = None,
-        y: Optional[ndarray] = None,
-    ) -> pd.DataFrame: ...
-    def precision(
-        self,
-        *,
-        data_source: Literal["test", "train", "X_y"] = "test",
-        X: Optional[ndarray] = None,
-        y: Optional[ndarray] = None,
-        average: Literal[
-            "auto", "macro", "micro", "weighted", "samples", None
-        ] = "auto",
-        pos_label: Optional[Union[str, int]] = None,
-    ) -> pd.DataFrame: ...
-    def recall(
-        self,
-        *,
-        data_source: Literal["test", "train", "X_y"] = "test",
-        X: Optional[ndarray] = None,
-        y: Optional[ndarray] = None,
-        average: Literal[
-            "auto", "macro", "micro", "weighted", "samples", None
-        ] = "auto",
-        pos_label: Optional[Union[str, int]] = None,
-    ) -> pd.DataFrame: ...
-    def brier_score(
-        self,
-        *,
-        data_source: Literal["test", "train", "X_y"] = "test",
-        X: Optional[ndarray] = None,
-        y: Optional[ndarray] = None,
-        pos_label: int = 1,
-    ) -> pd.DataFrame: ...
-    def roc_auc(
-        self,
-        *,
-        data_source: Literal["test", "train", "X_y"] = "test",
-        X: Optional[ndarray] = None,
-        y: Optional[ndarray] = None,
-        average: Literal["auto", "macro", "micro", "weighted", "samples"] = "auto",
-        multi_class: Literal["raise", "ovr", "ovo", "auto"] = "ovr",
-    ) -> pd.DataFrame: ...
-    def log_loss(
-        self,
-        *,
-        data_source: Literal["test", "train", "X_y"] = "test",
-        X: Optional[ndarray] = None,
-        y: Optional[ndarray] = None,
-    ) -> pd.DataFrame: ...
-    def r2(
-        self,
-        *,
-        data_source: Literal["test", "train", "X_y"] = "test",
-        X: Optional[ndarray] = None,
-        y: Optional[ndarray] = None,
-        multioutput: Union[
-            Literal["raw_values", "uniform_average"], ndarray
-        ] = "uniform_average",
-    ) -> pd.DataFrame: ...
-    def rmse(
-        self,
-        *,
-        data_source: Literal["test", "train", "X_y"] = "test",
-        X: Optional[ndarray] = None,
-        y: Optional[ndarray] = None,
-        multioutput: Union[
-            Literal["raw_values", "uniform_average"], ndarray
-        ] = "uniform_average",
-    ) -> pd.DataFrame: ...
-
-class EstimatorReport:
-    metrics: _MetricsAccessor
-
-    def __init__(
-        self,
-        estimator: BaseEstimator,
-        *,
-        fit: Literal["auto", True, False] = "auto",
-        X_train: Optional[ndarray] = None,
-        y_train: Optional[ndarray] = None,
-        X_test: Optional[ndarray] = None,
-        y_test: Optional[ndarray] = None,
-    ) -> None: ...
-    def clean_cache(self) -> None: ...
-    def cache_predictions(self, response_methods="auto", n_jobs=None) -> None: ...
-    def __repr__(self) -> str: ...
-    @property
-    def estimator(self) -> BaseEstimator: ...
-    @property
-    def X_train(self) -> Optional[ndarray]: ...
-    @property
-    def y_train(self) -> Optional[ndarray]: ...
-    @property
-    def X_test(self) -> Optional[ndarray]: ...
-    @property
-    def y_test(self) -> Optional[ndarray]: ...
-    def help(self) -> None: ...
diff --git a/skore/src/skore/sklearn/_estimator/metrics_accessor.pyi b/skore/src/skore/sklearn/_estimator/metrics_accessor.pyi
new file mode 100644
index 000000000..c73d136f1
--- /dev/null
+++ b/skore/src/skore/sklearn/_estimator/metrics_accessor.pyi
@@ -0,0 +1,168 @@
+from typing import Any, Callable, Literal, Optional, Union
+
+import matplotlib.axes
+import numpy as np
+import pandas as pd
+from sklearn.metrics import PrecisionRecallDisplay, RocCurveDisplay
+
+from skore.sklearn._estimator.base import _BaseAccessor
+from skore.sklearn._plot import PredictionErrorDisplay
+
+class _PlotMetricsAccessor(_BaseAccessor):
+    _metrics_parent: _MetricsAccessor
+
+    def __init__(self, parent: _MetricsAccessor) -> None: ...
+    def _get_display(
+        self,
+        *,
+        X: Optional[np.ndarray],
+        y: Optional[np.ndarray],
+        data_source: Literal["test", "train", "X_y"],
+        response_method: Union[str, list[str]],
+        display_class: Any,
+        display_kwargs: dict[str, Any],
+        display_plot_kwargs: dict[str, Any],
+    ) -> Union[RocCurveDisplay, PrecisionRecallDisplay, PredictionErrorDisplay]: ...
+    def roc(
+        self,
+        *,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+        pos_label: Optional[Union[str, int]] = None,
+        ax: Optional[matplotlib.axes.Axes] = None,
+    ) -> RocCurveDisplay: ...
+    def precision_recall(
+        self,
+        *,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+        pos_label: Optional[Union[str, int]] = None,
+        ax: Optional[matplotlib.axes.Axes] = None,
+    ) -> PrecisionRecallDisplay: ...
+    def prediction_error(
+        self,
+        *,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+        ax: Optional[matplotlib.axes.Axes] = None,
+        kind: Literal[
+            "actual_vs_predicted", "residual_vs_predicted"
+        ] = "residual_vs_predicted",
+        subsample: Optional[Union[int, float]] = 1_000,
+    ) -> PredictionErrorDisplay: ...
+
+class _MetricsAccessor(_BaseAccessor):
+    _SCORE_OR_LOSS_ICONS: dict[str, str]
+    plot: _PlotMetricsAccessor
+
+    def _compute_metric_scores(
+        self,
+        metric_fn: Callable,
+        X: Optional[np.ndarray],
+        y_true: Optional[np.ndarray],
+        *,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        response_method: Union[str, list[str]],
+        pos_label: Optional[Union[str, int]] = None,
+        metric_name: Optional[str] = None,
+        **metric_kwargs: Any,
+    ) -> pd.DataFrame: ...
+    def report_metrics(
+        self,
+        *,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+        scoring: Optional[Union[list[str], Callable]] = None,
+        pos_label: Optional[Union[str, int]] = None,
+        scoring_kwargs: Optional[dict[str, Any]] = None,
+    ) -> pd.DataFrame: ...
+    def accuracy(
+        self,
+        *,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+    ) -> pd.DataFrame: ...
+    def precision(
+        self,
+        *,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+        average: Optional[
+            Literal["binary", "micro", "macro", "weighted", "samples"]
+        ] = None,
+        pos_label: Optional[Union[str, int]] = None,
+    ) -> pd.DataFrame: ...
+    def recall(
+        self,
+        *,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+        average: Optional[
+            Literal["binary", "micro", "macro", "weighted", "samples"]
+        ] = None,
+        pos_label: Optional[Union[str, int]] = None,
+    ) -> pd.DataFrame: ...
+    def brier_score(
+        self,
+        *,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+        pos_label: Optional[Union[str, int]] = None,
+    ) -> pd.DataFrame: ...
+    def roc_auc(
+        self,
+        *,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+        average: Optional[
+            Literal["auto", "micro", "macro", "weighted", "samples"]
+        ] = None,
+        multi_class: Literal["raise", "ovr", "ovo"] = "ovr",
+    ) -> pd.DataFrame: ...
+    def log_loss(
+        self,
+        *,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+    ) -> pd.DataFrame: ...
+    def r2(
+        self,
+        *,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+        multioutput: Union[
+            Literal["raw_values", "uniform_average"], np.ndarray
+        ] = "raw_values",
+    ) -> pd.DataFrame: ...
+    def rmse(
+        self,
+        *,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+        multioutput: Union[
+            Literal["raw_values", "uniform_average"], np.ndarray
+        ] = "raw_values",
+    ) -> pd.DataFrame: ...
+    def custom_metric(
+        self,
+        metric_function: Callable,
+        response_method: Union[str, list[str]],
+        *,
+        metric_name: Optional[str] = None,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        X: Optional[np.ndarray] = None,
+        y: Optional[np.ndarray] = None,
+        **kwargs: Any,
+    ) -> pd.DataFrame: ...
diff --git a/skore/src/skore/sklearn/_estimator/report.pyi b/skore/src/skore/sklearn/_estimator/report.pyi
new file mode 100644
index 000000000..74c4c215d
--- /dev/null
+++ b/skore/src/skore/sklearn/_estimator/report.pyi
@@ -0,0 +1,71 @@
+from typing import Any, Literal, Optional, Union
+
+import numpy as np
+from sklearn.base import BaseEstimator
+
+from skore.sklearn._estimator.base import _HelpMixin
+from skore.sklearn._estimator.metrics_accessor import _MetricsAccessor
+
+class EstimatorReport(_HelpMixin):
+    _ACCESSOR_CONFIG: dict[str, dict[str, str]]
+    _estimator: BaseEstimator
+    _X_train: Optional[np.ndarray]
+    _y_train: Optional[np.ndarray]
+    _X_test: Optional[np.ndarray]
+    _y_test: Optional[np.ndarray]
+    _rng: np.random.Generator
+    _hash: int
+    _cache: dict[Any, Any]
+    _ml_task: str
+    metrics: _MetricsAccessor
+
+    @staticmethod
+    def _fit_estimator(
+        estimator: BaseEstimator, X_train: np.ndarray, y_train: Optional[np.ndarray]
+    ) -> BaseEstimator: ...
+    def __init__(
+        self,
+        estimator: BaseEstimator,
+        *,
+        fit: Literal["auto", True, False] = "auto",
+        X_train: Optional[np.ndarray] = None,
+        y_train: Optional[np.ndarray] = None,
+        X_test: Optional[np.ndarray] = None,
+        y_test: Optional[np.ndarray] = None,
+    ) -> None: ...
+    def _initialize_state(self) -> None: ...
+    def clean_cache(self) -> None: ...
+    def cache_predictions(
+        self,
+        response_methods: Union[Literal["auto"], list[str]] = "auto",
+        n_jobs: Optional[int] = None,
+    ) -> None: ...
+    @property
+    def estimator(self) -> BaseEstimator: ...
+    @property
+    def X_train(self) -> Optional[np.ndarray]: ...
+    @property
+    def y_train(self) -> Optional[np.ndarray]: ...
+    @property
+    def X_test(self) -> Optional[np.ndarray]: ...
+    @X_test.setter
+    def X_test(self, value: Optional[np.ndarray]) -> None: ...
+    @property
+    def y_test(self) -> Optional[np.ndarray]: ...
+    @y_test.setter
+    def y_test(self, value: Optional[np.ndarray]) -> None: ...
+    @property
+    def estimator_name(self) -> str: ...
+    def _get_cached_response_values(
+        self,
+        *,
+        estimator_hash: int,
+        estimator: BaseEstimator,
+        X: np.ndarray,
+        response_method: Union[str, list[str]],
+        pos_label: Optional[Union[str, int]] = None,
+        data_source: Literal["test", "train", "X_y"] = "test",
+        data_source_hash: Optional[str] = None,
+    ) -> np.ndarray: ...
+    def _get_help_panel_title(self) -> str: ...
+    def _create_help_tree(self) -> Any: ...  # Returns rich.tree.Tree