From 41e1177a73424409532f9b4627715888f3847316 Mon Sep 17 00:00:00 2001 From: Dennis Bader Date: Sun, 15 Sep 2024 12:41:49 +0200 Subject: [PATCH] =?UTF-8?q?fixes=20bug=20when=20plotting=20multivariate=20?= =?UTF-8?q?probabilistic=20series=20where=20the=20c=E2=80=A6=20(#2532)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fixes bug when plotting multivariate probabilistic series where the confidence intervals had the same color as central series * add labels back * update changelog --- CHANGELOG.md | 1 + darts/timeseries.py | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 02718b8b2e..7dd876b32d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Fixed** +- Fixed bug when plotting a probabilistic multivariate series, where all confidence intervals (starting from 2nd component) had the same color as the median line. [#2532](https://github.com/unit8co/darts/pull/2532) by [Dennis Bader](https://github.com/dennisbader). - Fixed a bug when passing an empty array to `TimeSeries.prepend/append_values()` raised an error. [#2522](https://github.com/unit8co/darts/pull/2522) by [Alessio Pellegrini](https://github.com/AlessiopSymplectic) - Fixed a bug with `TimeSeries.prepend/append_values()`, where the name of the (time) index was lost. [#2522](https://github.com/unit8co/darts/pull/2522) by [Alessio Pellegrini](https://github.com/AlessiopSymplectic) - Fixed a bug when using `from_group_dataframe()` with a `time_col` of type integer, where the resulting time index was wrongly converted to a DatetimeIndex. [#2512](https://github.com/unit8co/darts/pull/2512) by [Alessio Pellegrini](https://github.com/AlessiopSymplectic) diff --git a/darts/timeseries.py b/darts/timeseries.py index f0c9fc006d..edadecb422 100644 --- a/darts/timeseries.py +++ b/darts/timeseries.py @@ -4191,8 +4191,6 @@ def plot( central_series = comp.mean(dim=DIMS[2]) alpha = kwargs["alpha"] if "alpha" in kwargs else None - if not self.is_deterministic: - kwargs["alpha"] = 1 if custom_labels: label_to_use = label[i] else: @@ -4204,15 +4202,18 @@ def plot( label_to_use = f"{label}_{comp_name}" kwargs["label"] = label_to_use + kwargs_central = deepcopy(kwargs) + if not self.is_deterministic: + kwargs_central["alpha"] = 1 if central_series.shape[0] > 1: - p = central_series.plot(*args, ax=ax, **kwargs) + p = central_series.plot(*args, ax=ax, **kwargs_central) # empty TimeSeries elif central_series.shape[0] == 0: p = ax.plot( [], [], *args, - **kwargs, + **kwargs_central, ) ax.set_xlabel(self.time_index.name) else: @@ -4221,7 +4222,7 @@ def plot( central_series.values[0], "o", *args, - **kwargs, + **kwargs_central, ) color_used = p[0].get_color() if default_formatting else None