diff --git a/examples/getting_started/plot_quick_start.py b/examples/getting_started/plot_quick_start.py index 421ca0328..040a738e8 100644 --- a/examples/getting_started/plot_quick_start.py +++ b/examples/getting_started/plot_quick_start.py @@ -42,9 +42,6 @@ # %% reporter.plots.scores -# %% -reporter.plots.timing_normalized - # %% # Finally, from your shell (in the same directory), start the UI: # diff --git a/skore/src/skore/item/cross_validation_item.py b/skore/src/skore/item/cross_validation_item.py index e4873f3a0..205949744 100644 --- a/skore/src/skore/item/cross_validation_item.py +++ b/skore/src/skore/item/cross_validation_item.py @@ -344,7 +344,6 @@ def factory(cls, reporter: CrossValidationReporter) -> CrossValidationItem: humanized_plot_names = { "scores": "Scores", "timing": "Timings", - "timing_normalized": "Normalized timings", } plots_bytes = { humanized_plot_names[plot_name]: ( diff --git a/skore/src/skore/sklearn/cross_validation/cross_validation_reporter.py b/skore/src/skore/sklearn/cross_validation/cross_validation_reporter.py index 8dc52ab16..47d9b7d09 100644 --- a/skore/src/skore/sklearn/cross_validation/cross_validation_reporter.py +++ b/skore/src/skore/sklearn/cross_validation/cross_validation_reporter.py @@ -16,7 +16,6 @@ _strip_cv_results_scores, ) from .plots.compare_scores_plot import plot_cross_validation_compare_scores -from .plots.timing_normalized_plot import plot_cross_validation_timing_normalized from .plots.timing_plot import plot_cross_validation_timing @@ -26,7 +25,6 @@ class CrossValidationPlots: scores: plotly.graph_objects.Figure timing: plotly.graph_objects.Figure - timing_normalized: plotly.graph_objects.Figure class CrossValidationReporter: @@ -189,5 +187,4 @@ def plots(self) -> CrossValidationPlots: return CrossValidationPlots( scores=plot_cross_validation_compare_scores(self._cv_results), timing=plot_cross_validation_timing(self._cv_results), - timing_normalized=plot_cross_validation_timing_normalized(self._cv_results), ) diff --git a/skore/src/skore/sklearn/cross_validation/plots/timing_normalized_plot.py b/skore/src/skore/sklearn/cross_validation/plots/timing_normalized_plot.py deleted file mode 100644 index 767fee460..000000000 --- a/skore/src/skore/sklearn/cross_validation/plots/timing_normalized_plot.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Plot cross-validation normalized timing results.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from numpy import linspace - -if TYPE_CHECKING: - import plotly.graph_objects - - -def plot_cross_validation_timing_normalized( - cv_results: dict, -) -> plotly.graph_objects.Figure: - """Plot the normalized timing results of a cross-validation run. - - Each timing result is normalized by the number of data points. - - Parameters - ---------- - cv_results : dict - The output of scikit-learn's cross_validate function. - - Returns - ------- - plotly.graph_objects.Figure - A plot of the normalized time-related cross-validation results. - """ - from datetime import timedelta - - import pandas - import plotly - import plotly.graph_objects as go - - _cv_results = cv_results.copy() - - # Remove irrelevant keys - to_remove = [ - key - for key in _cv_results - if key not in ["fit_time_per_data_point", "score_time_per_data_point"] - ] - for key in to_remove: - _cv_results.pop(key, None) - - df = pandas.DataFrame(_cv_results) - - dict_labels = { - "fit_time_per_data_point": "fit_time_per_data_point (seconds)", - "score_time_per_data_point": "score_time_per_data_point (seconds)", - } - - fig = go.Figure() - - for col_i, col_name in enumerate(df.columns): - metric_name = dict_labels.get(col_name, col_name) - bar_color = plotly.colors.qualitative.Plotly[ - col_i % len(plotly.colors.qualitative.Plotly) - ] - bar_x = linspace(min(df.index) - 0.5, max(df.index) + 0.5, num=10) - - common_kwargs = dict( - visible=True if col_i == 0 else "legendonly", - legendgroup=f"group{col_i}", - ) - - # Calculate statistics - avg_value = df[col_name].mean() - std_value = df[col_name].std() - - # Add all traces at once - fig.add_traces( - [ - # Bar trace - go.Bar( - x=df.index, - y=df[col_name].values, - name=metric_name, - marker_color=bar_color, - showlegend=True, - hovertemplate=( - "%{customdata}" f"{col_name} (timedelta)" - ), - customdata=[str(timedelta(seconds=x)) for x in df[col_name].values], - **common_kwargs, - ), - # Mean line - go.Scatter( - x=bar_x, - y=[avg_value] * 10, - name=f"Average {metric_name}", - line=dict(dash="dash", color=bar_color), - showlegend=False, - mode="lines", - hovertemplate="%{customdata}", - customdata=[str(timedelta(seconds=avg_value))] * 10, - **common_kwargs, - ), - # +1 std line - go.Scatter( - x=bar_x, - y=[avg_value + std_value] * 10, - name=f"Average + 1 std. dev. {metric_name}", - line=dict(dash="dot", color=bar_color), - showlegend=False, - mode="lines", - hovertemplate="%{customdata}", - customdata=[str(timedelta(seconds=avg_value + std_value))] * 10, - **common_kwargs, - ), - # -1 std line - go.Scatter( - x=bar_x, - y=[avg_value - std_value] * 10, - name=f"Average - 1 std. dev. {metric_name}", - line=dict(dash="dot", color=bar_color), - showlegend=False, - mode="lines", - hovertemplate="%{customdata}", - customdata=[str(timedelta(seconds=avg_value + std_value))] * 10, - **common_kwargs, - ), - ] - ) - - fig.update_xaxes(tickmode="linear", dtick=1, title_text="Split index") - fig.update_yaxes(title_text="Value") - fig.update_layout( - title_text="Normalized time-related cross-validation results for each split" - ) - - return fig diff --git a/skore/tests/unit/item/test_cross_validation_item.py b/skore/tests/unit/item/test_cross_validation_item.py index bd9da08d6..3c3e557bf 100644 --- a/skore/tests/unit/item/test_cross_validation_item.py +++ b/skore/tests/unit/item/test_cross_validation_item.py @@ -37,7 +37,6 @@ class FakeCrossValidationReporter(CrossValidationReporter): plots = CrossValidationPlots( scores=plotly.graph_objects.Figure(), timing=plotly.graph_objects.Figure(), - timing_normalized=plotly.graph_objects.Figure(), ) cv = StratifiedKFold(n_splits=5) @@ -59,7 +58,6 @@ class FakeCrossValidationReporterNoGetParams(CrossValidationReporter): plots = CrossValidationPlots( scores=plotly.graph_objects.Figure(), timing=plotly.graph_objects.Figure(), - timing_normalized=plotly.graph_objects.Figure(), ) cv = StratifiedKFold(n_splits=5)