diff --git a/darts/explainability/shap_explainer.py b/darts/explainability/shap_explainer.py index 143ea0d8b9..b45fb275cf 100644 --- a/darts/explainability/shap_explainer.py +++ b/darts/explainability/shap_explainer.py @@ -25,7 +25,7 @@ """ from enum import Enum -from typing import Dict, NewType, Optional, Sequence, Union +from typing import Dict, List, NewType, Optional, Sequence, Union import matplotlib.pyplot as plt import pandas as pd @@ -375,7 +375,7 @@ def summary_plot( num_samples: Optional[int] = None, plot_type: Optional[str] = "dot", **kwargs, - ): + ) -> List[shap.Explanation]: """ Display a shap plot summary for each horizon and each component dimension of the target. This method reuses the initial background data as foreground (potentially sampled) to give a general importance @@ -395,6 +395,11 @@ def summary_plot( for the sake of performance. plot_type Optionally, specify which of the shap library plot type to use. Can be one of ``'dot', 'bar', 'violin'``. + + Returns + ------- + shap_explanations + A list containing the raw Explanations of the visualized the horizons and components """ horizons, target_components = self._process_horizons_and_targets( @@ -412,8 +417,10 @@ def summary_plot( foreground_X_sampled, horizons, target_components ) + shap_explanations = [] for t in target_components: for h in horizons: + shap_explanations.append(shaps_[h][t]) plt.title("Target: `{}` - Horizon: {}".format(t, "t+" + str(h))) shap.summary_plot( shaps_[h][t], @@ -422,6 +429,8 @@ def summary_plot( **kwargs, ) + return shap_explanations + def force_plot_from_ts( self, foreground_series: Optional[TimeSeries] = None, @@ -613,7 +622,7 @@ def __init__( def shap_explanations( self, - foreground_X, + foreground_X: pd.DataFrame, horizons: Optional[Sequence[int]] = None, target_components: Optional[Sequence[str]] = None, ) -> Dict[int, Dict[str, shap.Explanation]]: @@ -735,8 +744,8 @@ def _create_regression_model_shap_X( target_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]], past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]], future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]], - n_samples=None, - train=False, + n_samples: Optional[int] = None, + train: bool = False, ) -> pd.DataFrame: """ Creates the shap format input for regression models.