Skip to content

Commit

Permalink
refactor the show_anomalies_from_scores()
Browse files Browse the repository at this point in the history
  • Loading branch information
cnhwl committed Dec 31, 2024
1 parent b05eb8a commit 33006d4
Showing 1 changed file with 155 additions and 192 deletions.
347 changes: 155 additions & 192 deletions darts/ad/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,112 +469,23 @@ def show_anomalies_from_scores(

for i in range(series_width):
index_ax = i * nbr_plots

_plot_series(
_plot_series_and_anomalies(
series=series[series.components[i]],
ax_id=axs[index_ax][0],
linewidth=0.5,
label_name="",
)

if pred_series[pred_series.components[i]] is not None:
_plot_series(
series=pred_series[pred_series.components[i]],
ax_id=axs[index_ax][0],
linewidth=0.5,
label_name=pred_series.components[i] + " model_output",
)

axs[index_ax][0].set_title("")

if anomalies is not None or pred_scores is not None:
axs[index_ax][0].set_xlabel("")

axs[index_ax][0].legend(
loc="upper center", bbox_to_anchor=(0.5, 1.1), ncol=2
anomalies=anomalies[anomalies.components[i]]
if anomalies is not None
else None,
pred_series=pred_series[pred_series.components[i]]
if pred_series is not None
else None,
pred_scores=pred_scores,
window=window,
names_of_scorers=names_of_scorers,
metric=metric,
axs=axs,
index_ax=index_ax,
nbr_plots=nbr_plots,
)

if pred_scores is not None:
dict_input = {}

for idx, (score, w) in enumerate(zip(pred_scores, window)):
dict_input[idx] = {
"series_score": score,
"window": w,
"name_id": idx,
}

for index, elem in enumerate(
sorted(dict_input.items(), key=lambda x: x[1]["window"])
):
if index == 0:
current_window = elem[1]["window"]
index_ax = index_ax + 1

idx = elem[1]["name_id"]
w = elem[1]["window"]

if w != current_window:
current_window = w
index_ax = index_ax + 1

if metric is not None:
value = round(
eval_metric_from_scores(
anomalies=anomalies[anomalies.components[i]],
pred_scores=pred_scores[idx][
pred_scores[idx].components[i]
],
window=w,
metric=metric,
),
3,
)
else:
value = None

if names_of_scorers is not None:
label = (
names_of_scorers[idx] + [f" ({value})", ""][value is None]
)
else:
label = f"score_{str(idx)}" + [f" ({value})", ""][value is None]

_plot_series(
series=elem[1]["series_score"][
elem[1]["series_score"].components[i]
],
ax_id=axs[index_ax][0],
linewidth=0.5,
label_name=label,
)

axs[index_ax][0].legend(
loc="upper center", bbox_to_anchor=(0.5, 1.19), ncol=2
)
axs[index_ax][0].set_title(f"Window: {str(w)}", loc="left")
axs[index_ax][0].set_title("")
axs[index_ax][0].set_xlabel("")

if anomalies is not None:
_plot_series(
series=anomalies[anomalies.components[i]],
ax_id=axs[index_ax + 1][0],
linewidth=1,
label_name=anomalies.components[i],
color="red",
)

axs[index_ax + 1][0].set_title("")
axs[index_ax + 1][0].set_ylim([-0.1, 1.1])
axs[index_ax + 1][0].set_yticks([0, 1])
axs[index_ax + 1][0].set_yticklabels(["no", "yes"])
axs[index_ax + 1][0].legend(
loc="upper center", bbox_to_anchor=(0.5, 1.2), ncol=2
)
else:
axs[index_ax][0].set_xlabel("timestamp")

fig.suptitle(title)
else:
fig, axs = plt.subplots(
Expand All @@ -586,97 +497,19 @@ def show_anomalies_from_scores(
)

index_ax = 0

_plot_series(
series=series, ax_id=axs[index_ax][0], linewidth=0.5, label_name=""
_plot_series_and_anomalies(
series=series,
anomalies=anomalies,
pred_series=pred_series,
pred_scores=pred_scores,
window=window,
names_of_scorers=names_of_scorers,
metric=metric,
axs=axs,
index_ax=index_ax,
nbr_plots=nbr_plots,
)

if pred_series is not None:
_plot_series(
series=pred_series,
ax_id=axs[index_ax][0],
linewidth=0.5,
label_name="model output",
)

axs[index_ax][0].set_title("")

if anomalies is not None or pred_scores is not None:
axs[index_ax][0].set_xlabel("")

axs[index_ax][0].legend(loc="upper center", bbox_to_anchor=(0.5, 1.1), ncol=2)

if pred_scores is not None:
dict_input = {}

for idx, (score, w) in enumerate(zip(pred_scores, window)):
dict_input[idx] = {"series_score": score, "window": w, "name_id": idx}

for index, elem in enumerate(
sorted(dict_input.items(), key=lambda x: x[1]["window"])
):
if index == 0:
current_window = elem[1]["window"]
index_ax = index_ax + 1

idx = elem[1]["name_id"]
w = elem[1]["window"]

if w != current_window:
current_window = w
index_ax = index_ax + 1

if metric is not None:
value = round(
eval_metric_from_scores(
anomalies=anomalies,
pred_scores=pred_scores[idx],
window=w,
metric=metric,
),
3,
)
else:
value = None

if names_of_scorers is not None:
label = names_of_scorers[idx] + [f" ({value})", ""][value is None]
else:
label = f"score_{str(idx)}" + [f" ({value})", ""][value is None]

_plot_series(
series=elem[1]["series_score"],
ax_id=axs[index_ax][0],
linewidth=0.5,
label_name=label,
)

axs[index_ax][0].legend(
loc="upper center", bbox_to_anchor=(0.5, 1.19), ncol=2
)
axs[index_ax][0].set_title(f"Window: {str(w)}", loc="left")
axs[index_ax][0].set_title("")
axs[index_ax][0].set_xlabel("")

if anomalies is not None:
_plot_series(
series=anomalies,
ax_id=axs[index_ax + 1][0],
linewidth=1,
label_name="anomalies",
color="red",
)

axs[index_ax + 1][0].set_title("")
axs[index_ax + 1][0].set_ylim([-0.1, 1.1])
axs[index_ax + 1][0].set_yticks([0, 1])
axs[index_ax + 1][0].set_yticklabels(["no", "yes"])
axs[index_ax + 1][0].legend(
loc="upper center", bbox_to_anchor=(0.5, 1.2), ncol=2
)
else:
axs[index_ax][0].set_xlabel("timestamp")

fig.suptitle(title)


Expand Down Expand Up @@ -937,3 +770,133 @@ def _assert_fit_called(fit_called: bool, name: str):
),
logger=logger,
)


def _plot_series_and_anomalies(
series: TimeSeries,
anomalies: TimeSeries,
pred_series: TimeSeries,
pred_scores: Sequence[TimeSeries],
window: Sequence[int],
names_of_scorers: Sequence[str],
metric: str,
axs: plt.Axes,
index_ax: int,
nbr_plots: int,
):
"""Helper function to plot series and anomalies.
Parameters
----------
series
The actual series to visualize anomalies from.
anomalies
The ground truth of the anomalies (1 if it is an anomaly and 0 if not).
pred_series
Output of the model given as input the `series` (can be stochastic).
pred_scores
Output of the scorers given the output of the model and `series`.
window
Window parameter for each anomaly scores.
names_of_scorers
Name of the scores.
metric
The name of the metric function to use.
axs
The axes to plot on.
index_ax
The index of the current axis.
nbr_plots
The number of plots.
"""
_plot_series(series=series, ax_id=axs[index_ax][0], linewidth=0.5, label_name="")

if pred_series is not None:
_plot_series(
series=pred_series,
ax_id=axs[index_ax][0],
linewidth=0.5,
label_name="model output",
)

axs[index_ax][0].set_title("")

if anomalies is not None or pred_scores is not None:
axs[index_ax][0].set_xlabel("")

axs[index_ax][0].legend(loc="upper center", bbox_to_anchor=(0.5, 1.1), ncol=2)

if pred_scores is not None:
dict_input = {}

for idx, (score, w) in enumerate(zip(pred_scores, window)):
dict_input[idx] = {"series_score": score, "window": w, "name_id": idx}

for index, elem in enumerate(
sorted(dict_input.items(), key=lambda x: x[1]["window"])
):
if index == 0:
current_window = elem[1]["window"]
index_ax = index_ax + 1

idx = elem[1]["name_id"]
w = elem[1]["window"]

if w != current_window:
current_window = w
index_ax = index_ax + 1

if metric is not None:
value = round(
eval_metric_from_scores(
anomalies=anomalies,
pred_scores=pred_scores[idx][
pred_scores[idx].components[index_ax // nbr_plots]
],
window=w,
metric=metric,
),
3,
)
else:
value = None

if names_of_scorers is not None:
label = names_of_scorers[idx] + [f" ({value})", ""][value is None]
else:
label = f"score_{str(idx)}" + [f" ({value})", ""][value is None]

_plot_series(
series=elem[1]["series_score"][
elem[1]["series_score"].components[index_ax // nbr_plots]
],
ax_id=axs[index_ax][0],
linewidth=0.5,
label_name=label,
)

axs[index_ax][0].legend(
loc="upper center", bbox_to_anchor=(0.5, 1.19), ncol=2
)
axs[index_ax][0].set_title(f"Window: {str(w)}", loc="left")
axs[index_ax][0].set_title("")
axs[index_ax][0].set_xlabel("")

if anomalies is not None:
_plot_series(
series=anomalies,
ax_id=axs[index_ax + 1][0],
linewidth=1,
label_name="anomalies",
color="red",
)

axs[index_ax + 1][0].set_title("")
axs[index_ax + 1][0].set_ylim([-0.1, 1.1])
axs[index_ax + 1][0].set_yticks([0, 1])
axs[index_ax + 1][0].set_yticklabels(["no", "yes"])
axs[index_ax + 1][0].legend(
loc="upper center", bbox_to_anchor=(0.5, 1.2), ncol=2
)
else:
axs[index_ax][0].set_xlabel("timestamp")

0 comments on commit 33006d4

Please sign in to comment.