Skip to content

Commit

Permalink
Added PlotMLResults Task to create CM (columnflow#341)
Browse files Browse the repository at this point in the history
* Add MergeMLEvaluation task.

* Add MergeMLEvaluation task.

* starting point for PlotMLEvaluation task

* Linting.

* Added the MergeMLEvaluation task and Wrapper

* fixed linting issues

* Merged with upstream master

* name change + docs

* working PlotMlResults

* working cm but buggy plot ratio

* commit for further testing

* working plotting cm

* fix color bar size bug

* added sample weights + better docs

* Add ample_weights  and save multiple plots in pdf

* fixed trailing comma

* linting fixes

* fixed imports

* removed unneeded None as type from weights

* initial fixes

* apply review comments :)

* did not see some linting errors

* fixed issues with uncert and column normalization

* fixed wrong typing

* adapted to upcoming ROC curve

* trailing comma

* fixed

---------

Co-authored-by: Marcel R <[email protected]>
Co-authored-by: Mathis Frahm <[email protected]>
Co-authored-by: Stella S <[email protected]>
Co-authored-by: Marcel R <[email protected]>
  • Loading branch information
5 people authored Dec 7, 2023
1 parent 2186fc1 commit befe144
Show file tree
Hide file tree
Showing 2 changed files with 564 additions and 2 deletions.
310 changes: 310 additions & 0 deletions columnflow/plotting/plot_ml_evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
# coding: utf-8

"""
Useful plot functions for ML Evaluation
"""

from __future__ import annotations

import re

from columnflow.types import Sequence
from columnflow.util import maybe_import

ak = maybe_import("awkward")
od = maybe_import("order")
np = maybe_import("numpy")
sci = maybe_import("scinum")
plt = maybe_import("matplotlib.pyplot")
hep = maybe_import("mplhep")
colors = maybe_import("matplotlib.colors")

# define a CF custom color maps
cf_colors = {
"cf_green_cmap": colors.ListedColormap([
"#212121", "#242723", "#262D25", "#283426", "#2A3A26", "#2C4227", "#2E4927",
"#305126", "#325A25", "#356224", "#386B22", "#3B7520", "#3F7F1E", "#43891B",
"#479418", "#4C9F14", "#52AA10", "#58B60C", "#5FC207", "#67cf02",
]),
"cf_ygb_cmap": colors.ListedColormap([
"#003675", "#005B83", "#008490", "#009A83", "#00A368", "#00AC49", "#00B428",
"#00BC06", "#0CC300", "#39C900", "#67cf02", "#72DB02", "#7EE605", "#8DF207",
"#9CFD09", "#AEFF0B", "#C1FF0E", "#D5FF10", "#EBFF12", "#FFFF14",
]),
"cf_cmap": colors.ListedColormap([
"#002C9C", "#00419F", "#0056A2", "#006BA4", "#0081A7", "#0098AA", "#00ADAB",
"#00B099", "#00B287", "#00B574", "#00B860", "#00BB4C", "#00BD38", "#00C023",
"#00C20D", "#06C500", "#1EC800", "#36CA00", "#4ECD01", "#67cf02",
]),
"viridis": colors.ListedColormap([
"#263DA8", "#1652CC", "#1063DB", "#1171D8", "#1380D5", "#0E8ED0", "#089DCC",
"#0DA7C2", "#1DAFB3", "#2DB7A3", "#52BA91", "#73BD80", "#94BE71", "#B2BC65",
"#D0BA59", "#E1BF4A", "#F4C53A", "#FCD12B", "#FAE61C", "#F9F90E",
]),
}


def create_sample_weights(
sample_weights: Sequence[float] | bool,
events: dict,
true_labels: Sequence[str],
) -> dict:
"""
Helper function to create the sample weights for the events, if needed.
:param sample_weights: sample weights of the events. If an explicit array is not given,
the weights are calculated based on the number of events.
:param events: dictionary with the true labels as keys and the model output of the events as values.
:param true_labels: true labels of the events
:return: sample weights of the events
:raises ValueError: If both predictions and labels have mismatched shapes, or
if *weights* is not *None* and its shape doesn't match the predictions length.
"""
if not sample_weights:
return {label: 1 for label in true_labels}
if not (isinstance(sample_weights, bool) or (len(sample_weights) == len(true_labels))):
raise ValueError(f"Shape of sample_weights {len(sample_weights)} does not match "
f"shape of predictions {len(true_labels)}")
if isinstance(sample_weights, bool):
size = {label: len(event) for label, event in events.items()}
mean = np.mean(list(size.values()))
return {label: mean / length for label, length in size.items()}
return dict(zip(true_labels, sample_weights))


def plot_cm(
events: dict,
config_inst: od.Config,
category_inst: od.Category,
sample_weights: list | bool = False,
normalization: str = "row",
skip_uncertainties: bool = False,
x_labels: list[str] | None = None,
y_labels: list[str] | None = None,
cms_rlabel: str = "",
cms_llabel: str = "private work",
*args,
**kwargs,
) -> tuple[list[plt.Figure], np.ndarray]:
"""
Generates the figure of the confusion matrix given the output of the nodes
and an array of true labels. The Cronfusion matrix can also be weighted.
:param events: Dictionary with the true labels as keys and the model output of the events as values.
:param config_inst: The used config instance, for which the plot is created.
:param category_inst: The used category instance, for which the plot is created.
:param sample_weights: Sample weights applied to the confusion matrix the events.
If an explicit array is not given, the weights are calculated based on the number of events when set to *True*.
:param normalization: The type of normalization of the confusion matrix.
This parameter takes 'row', 'col' or '' (empty string) as argument.
If not provided, the matrix is row normalized.
:param skip_uncertainties: If true, no uncertainty of the cells will be shown in the plot.
:param x_labels: The labels for the x-axis. If not provided, the labels will be 'out<i>'
:param y_labels: The labels for the y-axis. If not provided, the dataset names are used.
:param *args: Additional arguments to pass to the function.
:param **kwargs: Additional keyword arguments to pass to the function.
:return: Returns the resulting plot and the confusion matrix.
:raises AssertionError: If both predictions and labels have mismatched shapes, or if *weights*
is not *None* and its shape doesn't match *predictions*.
:raises AssertionError: If *normalization* is not one of *None*, "row", "column".
"""
# defining some useful properties and output shapes
true_labels = list(events.keys())
pred_labels = [s.removeprefix("score_") for s in list(events.values())[0].fields]
return_type = np.float32 if sample_weights else np.int32
mat_shape = (len(true_labels), len(pred_labels))

def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray:
result = np.zeros(shape=mat_shape, dtype=return_type)
counts = np.zeros(shape=mat_shape, dtype=return_type)
sample_weights = create_sample_weights(sample_weights, events, true_labels)

# looping over the datasets
for ind, (dataset, pred) in enumerate(events.items()):
# remove awkward structure to use the numpy logic
pred = ak.to_numpy(pred)
pred = pred.view(float).reshape((pred.size, len(pred_labels)))

# create predictions of the model output
pred = np.argmax(pred, axis=-1)

for index, count in zip(*np.unique(pred, return_counts=True)):
result[ind, index] += count * sample_weights[dataset]
counts[ind, index] += count

if not skip_uncertainties:
vecNumber = np.vectorize(lambda n, count: sci.Number(n, float(n / np.sqrt(count) if count else 0)))
result = vecNumber(result, counts)

# normalize Matrix if needed
if normalization:
valid = {"row": 1, "column": 0}
if normalization not in valid.keys():
raise ValueError(
f"'{normalization}' is not a valid argument for normalization. If given, normalization "
"should only take 'row' or 'column'",
)

row_sums = result.sum(axis=valid.get(normalization))
result = result / row_sums[:, np.newaxis] if valid.get(normalization) else result / row_sums

return result

def plot_confusion_matrix(
cm: np.ndarray,
colormap: str = "cf_cmap",
cmap_label: str = "Accuracy",
digits: int = 3,
x_labels: list[str] | None = None,
y_labels: list[str] | None = None,
cms_rlabel: str = "",
cms_llabel: str = "private work",
*args,
**kwargs,
) -> plt.figure:
"""
Plots a confusion matrix.
:param cm: The confusion matrix to plot.
:param title: The title of the plot, displayed in the top right corner.
:param colormap: The name of the colormap to use. Can be selected from the following:
"cf_cmap", "cf_green_cmap", "cf_ygb_cmap", "viridis".
:param cmap_label: The label of the colorbar.
:param digits: The number of digits to display for each value in the matrix.
:param x_labels: The labels for the x-axis. If not provided, the labels will be "out<i>"
:param y_labels: The labels for the y-axis. If not provided, the dataset names are used.
:param *args: Additional arguments to pass to the function.
:param **kwargs: Additional keyword arguments to pass to the function.
:return: The resulting plot.
"""
from mpl_toolkits.axes_grid1 import make_axes_locatable

def calculate_font_size():
# get cell width
bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
width, height = fig.dpi * bbox.width, fig.dpi * bbox.height

# size of each cell in pixels
cell_width = width / n_classes
cell_height = height / n_processes

# calculate the font size based on the cell size to ensure font is not too large
font_size = min(cell_width, cell_height) / 10
font_size = max(min(font_size, 18), 8)

return font_size

def get_errors(matrix):
"""
Useful for seperating the error from the data
"""
if matrix.dtype.name == "object":
get_errors_vec = np.vectorize(lambda x: x.get(sci.UP, unc=True))
return get_errors_vec(matrix)
return np.zeros_like(matrix)

def value_text(i, j):
"""
Format the inputs as 'Number +- Uncertainty'
"""
def fmt(v):
s = "{{:.{}f}}".format(digits).format(v)
return s if re.sub(r"(0|\.)", "", s) else ("<" + s[:-1] + "1")
if skip_uncertainties:
return fmt(values[i][j])
return "{}\n\u00B1{}".format(fmt(values[i][j]), fmt(np.nan_to_num(uncs[i][j])))

# create the plot
plt.style.use(hep.style.CMS)
fig, ax = plt.subplots(dpi=300)

# some useful variables and functions
n_processes = cm.shape[0]
n_classes = cm.shape[1]
cmap = cf_colors.get(colormap, cf_colors["cf_cmap"])
x_labels = x_labels if x_labels else [f"out{i}" for i in range(n_classes)]
y_labels = y_labels if y_labels else true_labels
font_ax = 20
font_label = 20
font_text = calculate_font_size()

# get values and (if available) their uncertenties
values = cm.astype(np.float32)
uncs = get_errors(cm)

# remove Major ticks and edit minor ticks
minor_tick_length = max(int(120 / n_classes), 12) / 2
minor_tick_width = max(6 / n_classes, 0.6)
xtick_marks = np.arange(n_classes)
ytick_marks = np.arange(n_processes)

# plot the data
im = ax.imshow(values, interpolation="nearest", cmap=cmap)

# plot settings
thresh = values.max() / 2.
ax.tick_params(axis="both", which="major", bottom=False, top=False, left=False, right=False)
ax.tick_params(
axis="both",
which="minor",
bottom=True,
top=True,
left=True,
right=True,
length=minor_tick_length,
width=minor_tick_width,
)
ax.set_xticks(xtick_marks + 0.5, minor=True)
ax.set_yticks(ytick_marks + 0.5, minor=True)
ax.set_xticks(xtick_marks)
ax.set_xticklabels(x_labels, rotation=0, fontsize=font_label)
ax.set_yticks(ytick_marks)
ax.set_yticklabels(y_labels, fontsize=font_label)
ax.set_xlabel("Predicted process", loc="right", labelpad=10, fontsize=font_ax)
ax.set_ylabel("True process", loc="top", labelpad=15, fontsize=font_ax)

# adding a color bar on a new axis and adjusting its values
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.10)
colorbar = fig.colorbar(im, cax=cax)
colorbar.set_label(label=cmap_label, fontsize=font_ax)
colorbar.ax.tick_params(labelsize=font_ax - 5)
im.set_clim(0, max(1, values.max()))

# add Matrix Elemtns
for i in range(values.shape[0]):
for j in range(values.shape[1]):
ax.text(
j,
i,
value_text(i, j),
fontdict={"size": font_text},
horizontalalignment="center",
verticalalignment="center",
color="white" if values[i, j] < thresh else "black",
)

# final touches
hep.cms.label(ax=ax, llabel={"pw": "private work"}.get(cms_llabel, cms_llabel), rlabel=cms_rlabel)
plt.tight_layout()

return fig

cm = get_conf_matrix(sample_weights, *args, **kwargs)
print("Confusion matrix calculated!")

fig = plot_confusion_matrix(
cm,
x_labels=x_labels,
y_labels=y_labels,
cms_llabel=cms_llabel,
cms_rlabel=cms_rlabel,
*args,
**kwargs,
)
print("Confusion matrix plotted!")

return [fig], cm
Loading

0 comments on commit befe144

Please sign in to comment.