Skip to content

Commit

Permalink
feat: summary_plot returns the shap explanations
Browse files Browse the repository at this point in the history
  • Loading branch information
madtoinou committed Nov 2, 2023
1 parent f6e994e commit d57fa31
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions darts/explainability/shap_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"""

from enum import Enum
from typing import Dict, NewType, Optional, Sequence, Union
from typing import Dict, List, NewType, Optional, Sequence, Union

import matplotlib.pyplot as plt
import pandas as pd
Expand Down Expand Up @@ -375,7 +375,7 @@ def summary_plot(
num_samples: Optional[int] = None,
plot_type: Optional[str] = "dot",
**kwargs,
):
) -> List[shap.Explanation]:
"""
Display a shap plot summary for each horizon and each component dimension of the target.
This method reuses the initial background data as foreground (potentially sampled) to give a general importance
Expand All @@ -395,6 +395,11 @@ def summary_plot(
for the sake of performance.
plot_type
Optionally, specify which of the shap library plot type to use. Can be one of ``'dot', 'bar', 'violin'``.
Returns
-------
shap_explanations
A list containing the raw Explanations of the visualized the horizons and components
"""

horizons, target_components = self._process_horizons_and_targets(
Expand All @@ -412,8 +417,10 @@ def summary_plot(
foreground_X_sampled, horizons, target_components
)

shap_explanations = []
for t in target_components:
for h in horizons:
shap_explanations.append(shaps_[h][t])
plt.title("Target: `{}` - Horizon: {}".format(t, "t+" + str(h)))
shap.summary_plot(
shaps_[h][t],
Expand All @@ -422,6 +429,8 @@ def summary_plot(
**kwargs,
)

return shap_explanations

def force_plot_from_ts(
self,
foreground_series: Optional[TimeSeries] = None,
Expand Down Expand Up @@ -613,7 +622,7 @@ def __init__(

def shap_explanations(
self,
foreground_X,
foreground_X: pd.DataFrame,
horizons: Optional[Sequence[int]] = None,
target_components: Optional[Sequence[str]] = None,
) -> Dict[int, Dict[str, shap.Explanation]]:
Expand Down Expand Up @@ -735,8 +744,8 @@ def _create_regression_model_shap_X(
target_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]],
n_samples=None,
train=False,
n_samples: Optional[int] = None,
train: bool = False,
) -> pd.DataFrame:
"""
Creates the shap format input for regression models.
Expand Down

0 comments on commit d57fa31

Please sign in to comment.