From 4d08a193321bbaa1e9d7a31e7868a76cbd24df5d Mon Sep 17 00:00:00 2001 From: Thomas S Date: Mon, 13 Jan 2025 17:06:18 +0100 Subject: [PATCH] CrossValidationReporterItem's factory now pickles the reporter --- .../item/cross_validation_reporter_item.py | 369 +++++++----------- .../sklearn/test_cross_validate.py | 4 +- skore/tests/integration/ui/test_ui.py | 5 +- ...=> test_cross_validation_reporter_item.py} | 58 ++- 4 files changed, 172 insertions(+), 264 deletions(-) rename skore/tests/unit/item/{test_cross_validation_item.py => test_cross_validation_reporter_item.py} (80%) diff --git a/skore/src/skore/persistence/item/cross_validation_reporter_item.py b/skore/src/skore/persistence/item/cross_validation_reporter_item.py index 23c81dab9..263e1528c 100644 --- a/skore/src/skore/persistence/item/cross_validation_reporter_item.py +++ b/skore/src/skore/persistence/item/cross_validation_reporter_item.py @@ -1,20 +1,21 @@ -"""CrossValidationItem class. +"""CrossValidationReporterItem. -This class represents the output of a cross-validation workflow. +This module defines the CrossValidationReporterItem class, which is used to persist +reporters of cross-validation. """ from __future__ import annotations import contextlib -import copy import dataclasses import hashlib import importlib import json +import pickle import re import statistics from functools import cached_property -from typing import TYPE_CHECKING, Any, Literal, TypedDict, Union +from typing import TYPE_CHECKING, Literal, Optional, TypedDict import numpy import plotly.graph_objects @@ -27,8 +28,6 @@ if TYPE_CHECKING: import sklearn.base - CVSplitter = Any - class EstimatorParamInfo(TypedDict): """Information about an estimator parameter.""" @@ -43,6 +42,12 @@ class EstimatorInfo(TypedDict): params: dict[str, EstimatorParamInfo] +HUMANIZED_PLOT_NAMES = { + "scores": "Scores", + "timing": "Timings", +} + + def _hash_numpy(arr: numpy.ndarray) -> str: """Compute a hash string from a numpy array. @@ -117,86 +122,122 @@ def _params_to_str(estimator_info) -> str: return "\n".join(params_list) -# Data used for training, passed as input to scikit-learn -Data = Any -# Target used for training, passed as input to scikit-learn -Target = Any +def _estimator_info(estimator: sklearn.base.BaseEstimator) -> EstimatorInfo: + estimator_params = ( + estimator.get_params() if hasattr(estimator, "get_params") else {} + ) + + name = estimator.__class__.__name__ + module = estimator.__module__ + + # Figure out the default parameters of the estimator, + # so that we can highlight the non-default ones in the UI + + # This is done by instantiating the class with no arguments and + # computing the diff between the default and ours + try: + estimator_module = importlib.import_module(module) + EstimatorClass = getattr(estimator_module, name) + default_estimator_params = EstimatorClass().get_params() + except Exception: + default_estimator_params = {} + + final_estimator_params: dict[str, EstimatorParamInfo] = {} + for k, v in estimator_params.items(): + param_is_default: bool = ( + k in default_estimator_params and default_estimator_params[k] == v + ) + final_estimator_params[str(k)] = { + "value": repr(v), + "default": param_is_default, + } + return { + "name": name, + "module": module, + "params": final_estimator_params, + } -class CrossValidationItem(Item): - """ - A class to represent the output of a cross-validation workflow. - This class encapsulates the output of the - :func:`sklearn.model_selection.cross_validate` function along with its creation and - update timestamps. - """ +class CrossValidationReporterItem(Item): + """Class to persist the reporter of cross-validation.""" def __init__( self, - cv_results_serialized: dict, - estimator_info: EstimatorInfo, - X_info: dict, - y_info: Union[dict, None], - plots_bytes: dict[str, bytes], - cv_info: dict, - created_at: Union[str, None] = None, - updated_at: Union[str, None] = None, + reporter_bytes: bytes, + created_at: Optional[str] = None, + updated_at: Optional[str] = None, ): """ - Initialize a CrossValidationItem. + Initialize a CrossValidationReporterItem. Parameters ---------- - cv_results_serialized : dict - The dict output of the :func:`sklearn.model_selection.cross_validate` - function, in a form suitable for serialization. - estimator_info : dict - The estimator that was cross-validated. - X_info : dict - A summary of the data, input of the - :func:`sklearn.model_selection.cross_validate` function. - y_info : dict - A summary of the target, input of the - :func:`sklearn.model_selection.cross_validate` function. - plots_bytes : dict[str, bytes] - A collection of plots of the cross-validation results, in the form of bytes. - cv_info: dict - A dict containing cross validation splitting strategy params. - created_at : str + reporter_bytes : bytes + The raw bytes of the reporter pickle representation. + created_at : str, optional The creation timestamp in ISO format. - updated_at : str + updated_at : str, optional The last update timestamp in ISO format. """ super().__init__(created_at, updated_at) - self.cv_results_serialized = cv_results_serialized - self.estimator_info = estimator_info - self.X_info = X_info - self.y_info = y_info - self.plots_bytes = plots_bytes - self.cv_info = cv_info + self.reporter_bytes = reporter_bytes - def as_serializable_dict(self): - """Get a serializable dict from the item. + @classmethod + def factory(cls, reporter: CrossValidationReporter) -> CrossValidationReporterItem: + """ + Create a CrossValidationReporterItem instance from a CrossValidationReporter. - Derived class must call their super implementation - and merge the result with their output. + Parameters + ---------- + reporter : CrossValidationReporter + + Returns + ------- + CrossValidationReporterItem + A new CrossValidationReporterItem instance. """ + if not isinstance(reporter, CrossValidationReporter): + raise ItemTypeError(f"Type '{reporter.__class__}' is not supported.") + + instance = cls(pickle.dumps(reporter)) + + # add reporter as cached property + instance.reporter = reporter + + return instance + + @cached_property + def reporter(self) -> CrossValidationReporter: + """The CrossValidationReporter from the persistence.""" + return pickle.loads(self.reporter_bytes) + + def as_serializable_dict(self): + """Get a serializable dict from the item.""" # Get tabular results (the cv results in a dataframe-like structure) - cv_results = copy.deepcopy(self.cv_results_serialized) - cv_results.pop("indices", None) - - metrics_names = list(cv_results.keys()) - tabular_results = { - "name": "Cross validation results", - "columns": metrics_names, - "data": list(zip(*cv_results.values())), - "favorability": [_metric_favorability(m) for m in metrics_names], + cv_results = { + key: value.tolist() + for key, value in self.reporter.cv_results.items() + if ( + key != "estimator" + and key != "indices" + and isinstance(value, numpy.ndarray) + ) } + metrics_names = list(cv_results) + tabular_results = [ + { + "name": "Cross validation results", + "columns": metrics_names, + "data": list(zip(*cv_results.values())), + "favorability": [_metric_favorability(m) for m in metrics_names], + } + ] + # Get scalar results (summary statistics of the cv results) - mean_cv_results = [ + scalar_results = [ { "name": _metric_title(k), "value": statistics.mean(v), @@ -206,163 +247,22 @@ def as_serializable_dict(self): for k, v in cv_results.items() ] - scalar_results = mean_cv_results - - params_as_str = _params_to_str(self.estimator_info) - # If the estimator is from sklearn, make the class name a hyperlink # to the relevant docs - name = self.estimator_info["name"] - module = re.sub(r"\.\_.+", "", self.estimator_info["module"]) + estimator_info = _estimator_info(self.reporter.estimator) + name = estimator_info["name"] + module = re.sub(r"\.\_.+", "", estimator_info["module"]) if module.startswith("sklearn"): doc_url = f"https://scikit-learn.org/stable/modules/generated/{module}.{name}.html" doc_link = f'{name}' else: doc_link = f"`{name}`" + params_as_str = _params_to_str(estimator_info) estimator_params_as_str = f"{doc_link}\n{params_as_str}" - # Get cross-validation details - cv_params_as_str = ", ".join(f"{k}: *{v}*" for k, v in self.cv_info.items()) - - r = super().as_serializable_dict() - sections = [ - { - "title": "Model", - "icon": "icon-square-cursor", - "items": [ - { - "name": "Estimator parameters", - "description": "Core model configuration used for training", - "value": estimator_params_as_str, - }, - { - "name": "Cross-validation parameters", - "description": "Controls how data is split and validated", - "value": cv_params_as_str, - }, - ], - } - ] - value = { - "scalar_results": scalar_results, - "tabular_results": [tabular_results], - "plots": [ - { - "name": plot_name, - "value": json.loads(plot_bytes.decode("utf-8")), - } - for plot_name, plot_bytes in self.plots_bytes.items() - ], - "sections": sections, - } - r.update( - { - "media_type": "application/vnd.skore.cross_validation+json", - "value": value, - } - ) - return r - - @staticmethod - def _estimator_info(estimator: sklearn.base.BaseEstimator) -> EstimatorInfo: - estimator_params = ( - estimator.get_params() if hasattr(estimator, "get_params") else {} - ) - - name = estimator.__class__.__name__ - module = estimator.__module__ - - # Figure out the default parameters of the estimator, - # so that we can highlight the non-default ones in the UI - - # This is done by instantiating the class with no arguments and - # computing the diff between the default and ours - try: - estimator_module = importlib.import_module(module) - EstimatorClass = getattr(estimator_module, name) - default_estimator_params = EstimatorClass().get_params() - except Exception: - default_estimator_params = {} - - final_estimator_params: dict[str, EstimatorParamInfo] = {} - for k, v in estimator_params.items(): - param_is_default: bool = ( - k in default_estimator_params and default_estimator_params[k] == v - ) - final_estimator_params[str(k)] = { - "value": repr(v), - "default": param_is_default, - } - - return { - "name": name, - "module": module, - "params": final_estimator_params, - } - - @classmethod - def factory(cls, reporter: CrossValidationReporter) -> CrossValidationItem: - """ - Create a new CrossValidationItem instance from a CrossValidationReporter. - - Parameters - ---------- - reporter : CrossValidationReporter - - Returns - ------- - CrossValidationItem - A new CrossValidationItem instance. - """ - if not isinstance(reporter, CrossValidationReporter): - raise ItemTypeError( - f"Type '{reporter.__class__}' is not supported, " - f"only '{CrossValidationReporter.__name__}' is." - ) - - cv_results = reporter._cv_results - estimator = reporter.estimator - X = reporter.X - y = reporter.y - plots = reporter.plots - cv = reporter.cv - - cv_results_serialized = {} - for k, v in cv_results.items(): - if k == "estimator": - continue - if k == "indices": - cv_results_serialized["indices"] = { - "train": tuple(arr.tolist() for arr in v["train"]), - "test": tuple(arr.tolist() for arr in v["test"]), - } - if isinstance(v, numpy.ndarray): - cv_results_serialized[k] = v.tolist() - - estimator_info = CrossValidationItem._estimator_info(estimator) - - y_array = y if isinstance(y, numpy.ndarray) else numpy.array(y) - y_info = None if y is None else {"hash": _hash_numpy(y_array)} - - X_array = X if isinstance(X, numpy.ndarray) else numpy.array(X) - X_info = { - "nb_rows": X_array.shape[0], - "nb_cols": X_array.shape[1], - "hash": _hash_numpy(X_array), - } - - humanized_plot_names = { - "scores": "Scores", - "timing": "Timings", - } - plots_bytes = { - humanized_plot_names[plot_name]: ( - plotly.io.to_json(plot, engine="json").encode("utf-8") - ) - for plot_name, plot in dataclasses.asdict(plots).items() - } - + # + cv = self.reporter.cv cv_info: dict[str, str] = {} if isinstance(cv, int): cv_info["n_splits"] = repr(cv) @@ -376,19 +276,40 @@ def factory(cls, reporter: CrossValidationReporter) -> CrossValidationItem: attr = getattr(cv, attr_name) cv_info[attr_name] = repr(attr) - return cls( - cv_results_serialized=cv_results_serialized, - estimator_info=estimator_info, - X_info=X_info, - y_info=y_info, - plots_bytes=plots_bytes, - cv_info=cv_info, - ) + cv_params_as_str = ", ".join(f"{k}: *{v}*" for k, v in cv_info.items()) - @cached_property - def plots(self) -> dict: - """Various plots of the cross-validation results.""" - return { - name: plotly.io.from_json(plot_bytes.decode("utf-8")) - for name, plot_bytes in self.plots_bytes.items() + # + value = { + "scalar_results": scalar_results, + "tabular_results": tabular_results, + "plots": [ + { + "name": HUMANIZED_PLOT_NAMES[plot_name], + "value": json.loads(plotly.io.to_json(plot, engine="json")), + } + for plot_name, plot in dataclasses.asdict(self.reporter.plots).items() + ], + "sections": [ + { + "title": "Model", + "icon": "icon-square-cursor", + "items": [ + { + "name": "Estimator parameters", + "description": "Core model configuration used for training", + "value": estimator_params_as_str, + }, + { + "name": "Cross-validation parameters", + "description": "Controls how data is split and validated", + "value": cv_params_as_str, + }, + ], + } + ], + } + + return super().as_serializable_dict() | { + "media_type": "application/vnd.skore.cross_validation+json", + "value": value, } diff --git a/skore/tests/integration/sklearn/test_cross_validate.py b/skore/tests/integration/sklearn/test_cross_validate.py index e39f6ba58..00758f751 100644 --- a/skore/tests/integration/sklearn/test_cross_validate.py +++ b/skore/tests/integration/sklearn/test_cross_validate.py @@ -10,7 +10,7 @@ from sklearn.multiclass import OneVsOneClassifier from sklearn.svm import SVC from skore import CrossValidationReporter -from skore.persistence.item.cross_validation_item import CrossValidationItem +from skore.persistence.item import CrossValidationReporterItem from skore.sklearn.cross_validation.cross_validation_helpers import _get_scorers_to_add @@ -201,7 +201,7 @@ def test_cross_validation_reporter(in_memory_project, fixture_name, request): in_memory_project.put("cross-validation", reporter) retrieved_item = in_memory_project.item_repository.get_item("cross-validation") - assert isinstance(retrieved_item, CrossValidationItem) + assert isinstance(retrieved_item, CrossValidationReporterItem) @pytest.mark.parametrize( diff --git a/skore/tests/integration/ui/test_ui.py b/skore/tests/integration/ui/test_ui.py index 360761a6f..44b15d7a3 100644 --- a/skore/tests/integration/ui/test_ui.py +++ b/skore/tests/integration/ui/test_ui.py @@ -167,7 +167,7 @@ def _fake_cross_validate(*args, **kwargs): monkeypatch.setattr("sklearn.model_selection.cross_validate", _fake_cross_validate) -def test_serialize_cross_validation_item( +def test_serialize_cross_validation_reporter_item( client, in_memory_project, monkeypatch, @@ -176,9 +176,6 @@ def test_serialize_cross_validation_item( fake_cross_validate, ): monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) - monkeypatch.setattr( - "skore.persistence.item.cross_validation_item.CrossValidationItem.plots", {} - ) monkeypatch.setattr( "skore.sklearn.cross_validation.cross_validation_reporter.plot_cross_validation_compare_scores", lambda _: {}, diff --git a/skore/tests/unit/item/test_cross_validation_item.py b/skore/tests/unit/item/test_cross_validation_reporter_item.py similarity index 80% rename from skore/tests/unit/item/test_cross_validation_item.py rename to skore/tests/unit/item/test_cross_validation_reporter_item.py index 909eff872..973fc933d 100644 --- a/skore/tests/unit/item/test_cross_validation_item.py +++ b/skore/tests/unit/item/test_cross_validation_reporter_item.py @@ -1,13 +1,13 @@ from dataclasses import dataclass +from pickle import dumps import numpy import plotly.graph_objects import pytest from sklearn.model_selection import StratifiedKFold from skore.persistence.item import ItemTypeError -from skore.persistence.item.cross_validation_item import ( - CrossValidationItem, - _hash_numpy, +from skore.persistence.item.cross_validation_reporter_item import ( + CrossValidationReporterItem, _metric_favorability, ) from skore.sklearn.cross_validation import CrossValidationReporter @@ -27,7 +27,7 @@ class FakeEstimatorNoGetParams: @dataclass class FakeCrossValidationReporter(CrossValidationReporter): - _cv_results = { + cv_results = { "test_score": numpy.array([1, 2, 3]), "estimator": [FakeEstimator(), FakeEstimator(), FakeEstimator()], "fit_time": [1, 2, 3], @@ -44,7 +44,7 @@ class FakeCrossValidationReporter(CrossValidationReporter): @dataclass class FakeCrossValidationReporterNoGetParams(CrossValidationReporter): - _cv_results = { + cv_results = { "test_score": numpy.array([1, 2, 3]), "estimator": [ FakeEstimatorNoGetParams(), @@ -63,14 +63,14 @@ class FakeCrossValidationReporterNoGetParams(CrossValidationReporter): cv = StratifiedKFold(n_splits=5) -class TestCrossValidationItem: +class TestCrossValidationReporterItem: @pytest.fixture(autouse=True) def monkeypatch_datetime(self, monkeypatch, MockDatetime): monkeypatch.setattr("skore.persistence.item.item.datetime", MockDatetime) def test_factory_exception(self): with pytest.raises(ItemTypeError): - CrossValidationItem.factory(None) + CrossValidationReporterItem.factory(None) @pytest.mark.parametrize( "reporter", @@ -82,42 +82,32 @@ def test_factory_exception(self): ], ) def test_factory(self, mock_nowstr, reporter): - item = CrossValidationItem.factory(reporter) - - assert item.cv_results_serialized == {"test_score": [1, 2, 3]} - assert item.estimator_info == { - "name": reporter.estimator.__class__.__name__, - "params": ( - {} - if isinstance(reporter.estimator, FakeEstimatorNoGetParams) - else {"alpha": {"value": "3", "default": True}} - ), - "module": "tests.unit.item.test_cross_validation_item", - } - assert item.X_info == { - "nb_cols": 1, - "nb_rows": 1, - "hash": _hash_numpy(FakeCrossValidationReporter.X), - } - assert item.y_info == {"hash": _hash_numpy(FakeCrossValidationReporter.y)} - assert item.cv_info == { - "n_splits": "5", - "random_state": "None", - "shuffle": "False", - } - assert isinstance(item.plots_bytes, dict) - assert isinstance(item.plots, dict) + item = CrossValidationReporterItem.factory(reporter) + + assert item.reporter_bytes == dumps(reporter) assert item.created_at == mock_nowstr assert item.updated_at == mock_nowstr + def test_reporter(self, mock_nowstr): + reporter = FakeCrossValidationReporter() + item1 = CrossValidationReporterItem.factory(reporter) + item2 = CrossValidationReporterItem( + reporter_bytes=dumps(reporter), + created_at=mock_nowstr, + updated_at=mock_nowstr, + ) + + assert item1.reporter == reporter + assert item2.reporter == reporter + def test_get_serializable_dict(self, monkeypatch, mock_nowstr): monkeypatch.setattr( - "skore.persistence.item.cross_validation_item.CrossValidationReporter", + "skore.persistence.item.cross_validation_reporter_item.CrossValidationReporter", FakeCrossValidationReporter, ) reporter = FakeCrossValidationReporter() - item = CrossValidationItem.factory(reporter) + item = CrossValidationReporterItem.factory(reporter) serializable = item.as_serializable_dict() assert serializable["updated_at"] == mock_nowstr