Skip to content

Commit

Permalink
Merge branch 'PlotMLResults_Task' of github.com:haddadanas/columnflow…
Browse files Browse the repository at this point in the history
… into ROC_PlotTask
  • Loading branch information
haddadanas committed Dec 7, 2023
2 parents a364bc8 + e8f9859 commit c582628
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 55 deletions.
73 changes: 39 additions & 34 deletions columnflow/plotting/plot_ml_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import re

from columnflow.types import Sequence, Dict, List, Tuple
from columnflow.types import Sequence
from columnflow.util import maybe_import

ak = maybe_import("awkward")
Expand All @@ -19,7 +19,7 @@
hep = maybe_import("mplhep")
colors = maybe_import("matplotlib.colors")

# Define a CF custom color maps
# define CF custom color maps
cf_colors = {
"cf_green_cmap": colors.ListedColormap([
"#212121", "#242723", "#262D25", "#283426", "#2A3A26", "#2C4227", "#2E4927",
Expand Down Expand Up @@ -85,23 +85,26 @@ def plot_cm(
cms_llabel: str = "private work",
*args,
**kwargs,
) -> Tuple[List[plt.Figure], np.ndarray]:
""" Generates the figure of the confusion matrix given the output of the nodes
) -> tuple[list[plt.Figure], np.ndarray]:
"""
Generates the figure of the confusion matrix given the output of the nodes
and an array of true labels. The confusion matrix can also be weighted.
:param events: dictionary with the true labels as keys and the model output of the events as values.
:param config_inst: used configuration for the plot.
:param category_inst: used category instance, for which the plot is created.
:param sample_weights: sample weights of the events. If an explicit array is not given, the weights are
calculated based on the number of events.
:param normalization: type of normalization of the confusion matrix. If not provided, the matrix is row normalized.
:param events: Dictionary with the true labels as keys and the model output of the events as values.
:param config_inst: The used config instance, for which the plot is created.
:param category_inst: The used category instance, for which the plot is created.
:param sample_weights: Sample weights of the events applied to the confusion matrix.
If an explicit array is not given, the weights are calculated based on the number of events when set to *True*.
:param normalization: The type of normalization of the confusion matrix.
This parameter takes 'row', 'column' or '' (empty string) as argument.
If not provided, the matrix is row normalized.
:param skip_uncertainties: If true, no uncertainty of the cells will be shown in the plot.
:param x_labels: labels for the x-axis.
:param y_labels: labels for the y-axis.
:param x_labels: The labels for the x-axis. If not provided, the labels will be 'out<i>'
:param y_labels: The labels for the y-axis. If not provided, the dataset names are used.
:param *args: Additional arguments to pass to the function.
:param **kwargs: Additional keyword arguments to pass to the function.
:return: The resulting plot and the confusion matrix.
:return: Returns the resulting plot and the confusion matrix.
:raises AssertionError: If both predictions and labels have mismatched shapes, or if *weights*
is not *None* and its shape doesn't match *predictions*.
Expand Down Expand Up @@ -135,13 +138,13 @@ def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray:
vecNumber = np.vectorize(lambda n, count: sci.Number(n, float(n / np.sqrt(count) if count else 0)))
result = vecNumber(result, counts)

# Normalize Matrix if needed
# normalize Matrix if needed
if normalization:
valid = {"row": 1, "column": 0}
if normalization not in valid.keys():
raise ValueError(
f"\"{normalization}\" is not a valid argument for normalization. If given, normalization "
"should only take \"row\" or \"column\"",
f"'{normalization}' is not a valid argument for normalization. If given, normalization "
"should only take 'row' or 'column'",
)

row_sums = result.sum(axis=valid.get(normalization))
Expand Down Expand Up @@ -180,15 +183,15 @@ def plot_confusion_matrix(
from mpl_toolkits.axes_grid1 import make_axes_locatable

def calculate_font_size():
# Get cell width
# get cell width
bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
width, height = fig.dpi * bbox.width, fig.dpi * bbox.height

# Size of each cell in pixels
# size of each cell in pixels
cell_width = width / n_classes
cell_height = height / n_processes

# Calculate the font size based on the cell size to ensure font is not too large
# calculate the font size based on the cell size to ensure font is not too large
font_size = min(cell_width, cell_height) / 10
font_size = max(min(font_size, 18), 8)

Expand Down Expand Up @@ -218,7 +221,7 @@ def fmt(v):
plt.style.use(hep.style.CMS)
fig, ax = plt.subplots(dpi=300)

# Some useful variables and functions
# some useful variables and functions
n_processes = cm.shape[0]
n_classes = cm.shape[1]
cmap = cf_colors.get(colormap, cf_colors["cf_cmap"])
Expand All @@ -228,11 +231,11 @@ def fmt(v):
font_label = 20
font_text = calculate_font_size()

# Get values and (if available) their uncertenties
# get values and (if available) their uncertainties
values = cm.astype(np.float32)
uncs = get_errors(cm)

# Remove Major ticks and edit minor ticks
# remove Major ticks and edit minor ticks
minor_tick_length = max(int(120 / n_classes), 12) / 2
minor_tick_width = max(6 / n_classes, 0.6)
xtick_marks = np.arange(n_classes)
Expand All @@ -241,7 +244,7 @@ def fmt(v):
# plot the data
im = ax.imshow(values, interpolation="nearest", cmap=cmap)

# Plot settings
# plot settings
thresh = values.max() / 2.
ax.tick_params(axis="both", which="major", bottom=False, top=False, left=False, right=False)
ax.tick_params(
Expand Down Expand Up @@ -271,7 +274,7 @@ def fmt(v):
colorbar.ax.tick_params(labelsize=font_ax - 5)
im.set_clim(0, max(1, values.max()))

# Add Matrix Elemtns
# add matrix elements
for i in range(values.shape[0]):
for j in range(values.shape[1]):
ax.text(
Expand Down Expand Up @@ -319,7 +322,7 @@ def plot_roc(
cms_llabel: str = "private work",
*args,
**kwargs,
) -> tuple[List[plt.Figure], dict]:
) -> tuple[list[plt.Figure], dict]:
"""
Generates the figure of the ROC curve given the output of the nodes
and an array of true labels. The ROC curve can also be weighted.
Expand All @@ -331,7 +334,7 @@ def plot_roc(
calculated based on the number of events.
:param n_thresholds: number of thresholds to use for the ROC curve.
:param skip_discriminators: list of discriminators to skip.
:param evaluation_type: type of evaluation to use for the ROC curve. If not provided, the type is "OvR".
:param evaluation_type: type of evaluation to use for the ROC curve. If not provided, the type is 'OvR'.
:param cms_rlabel: right label of the CMS label.
:param cms_llabel: left label of the CMS label.
:param *args: Additional arguments to pass to the function.
Expand All @@ -341,7 +344,7 @@ def plot_roc(
:raises ValueError: If both predictions and labels have mismatched shapes, or if *weights*
is not *None* and its shape doesn't match *predictions*.
:raises ValueError: If *normalization* is not one of *None*, "row", "column".
:raises ValueError: If *normalization* is not one of *None*, 'row', 'column'.
"""
# defining some useful properties and output shapes
thresholds = np.linspace(0, 1, n_thresholds)
Expand All @@ -350,10 +353,12 @@ def plot_roc(
figs = []

if evaluation_type not in ["OvO", "OvR"]:
raise ValueError("Illeagal Argument! Evaluation Type can only be choosen as \"OvO\" (One vs One) \
or \"OvR\" (One vs Rest)")
raise ValueError(
"Illeagal Argument! Evaluation Type can only be choosen as 'OvO' (One vs One)"
"or 'OvR' (One vs Rest)"
)

def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> Dict[str, Dict[str, np.ndarray]]:
def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> dict[str, dict[str, np.ndarray]]:
"""
Helper function to create histograms for the different discriminators and classes.
"""
Expand All @@ -370,12 +375,12 @@ def binary_roc_data(
negativ_hist: np.ndarray,
*args,
**kwargs,
) -> Tuple[np.ndarray, np.ndarray]:
) -> tuple[np.ndarray, np.ndarray]:
"""
Compute binary Receiver operating characteristic (ROC) values.
Used as a helper function for the multi-dimensional ROC curve
"""
# Calculate the different rates
# calculate the different rates
fn = np.cumsum(positiv_hist)
tn = np.cumsum(negativ_hist)
tp = fn[-1] - fn
Expand All @@ -391,7 +396,7 @@ def roc_curve_data(
histograms: dict,
*args,
**kwargs,
) -> Dict[str, Dict[str, np.ndarray]]:
) -> dict[str, dict[str, np.ndarray]]:
"""
Compute Receiver operating characteristic (ROC) values for a multi-dimensional output.
"""
Expand All @@ -403,7 +408,7 @@ def roc_curve_data(
if disc in skip_discriminators:
continue

# Choose the evaluation type
# choose the evaluation type
if (evaluation_type == "OvO"):
for pos_cls, pos_hist in histograms[disc].items():
for neg_cls, neg_hist in histograms[disc].items():
Expand Down
36 changes: 18 additions & 18 deletions columnflow/tasks/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
"""
from __future__ import annotations

from collections import OrderedDict

from collections import defaultdict
from collections import OrderedDict, defaultdict

import law
import luigi
Expand All @@ -32,7 +30,6 @@
from columnflow.tasks.production import ProduceColumns
from columnflow.util import dev_sandbox, safe_div, DotDict
from columnflow.util import maybe_import
from columnflow.types import Dict, List


ak = maybe_import("awkward")
Expand Down Expand Up @@ -947,9 +944,9 @@ class PlotMLResultsBase(
)

skip_processes = law.CSVParameter(
default=("",),
default=(),
description="comma seperated list of process names to skip; these processes will not be included in the plots. "
"default: ('',)",
"default: ()",
brace_expand=True,
)

Expand Down Expand Up @@ -1005,22 +1002,22 @@ def workflow_requires(self: PlotMLResultsBase, only_super: bool = False):

return reqs

def output(self: PlotMLResultsBase) -> Dict[str, List]:
def output(self: PlotMLResultsBase) -> dict[str, list]:
b = self.branch_data
return {"plots": [
self.target(name)
for name in self.get_plot_names(f"plot__proc_{self.processes_repr}__cat_{b.category}")
]}

def prepare_inputs(self: PlotMLResultsBase) -> Dict[str, ak.Array]:
def prepare_inputs(self: PlotMLResultsBase) -> dict[str, ak.Array]:
"""
prepare the inputs for the plot function, based on the given configuration and category.
:raises NotImplementedError: This error is raised if a given dataset contains more than one process.
:raises ValueError: This error is raised if ``plot_sub_processes`` is used without providing the
``process_ids`` column in the data
:return: Dict[str, ak.Array]: A dictionary with the dataset names as keys and
:return: dict[str, ak.Array]: A dictionary with the dataset names as keys and
the corresponding predictions as values.
"""
category_inst = self.config_inst.get_category(self.branch_data.category)
Expand Down Expand Up @@ -1063,8 +1060,10 @@ def prepare_inputs(self: PlotMLResultsBase) -> Dict[str, ak.Array]:
all_events[process_inst.name] = getattr(events, self.ml_model)
else:
if "process_ids" not in events.fields:
raise ValueError("No `process_ids` column stored in the events! "
f"Process selection for {dataset} cannot not be applied!")
raise ValueError(
"No `process_ids` column stored in the events! "
f"Process selection for {dataset} cannot not be applied!",
)
for sub_process in sub_process_insts[process_inst]:
if sub_process.name in self.skip_processes:
continue
Expand Down Expand Up @@ -1112,15 +1111,16 @@ def prepare_plot_parameters(self: PlotMLResults):

def output(self: PlotMLResults):
b = self.branch_data
return {"plots": [
self.target(name)
for name in self.get_plot_names(
f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/plot__0",
)
],
return {
"plots": [
self.target(name)
for name in self.get_plot_names(
f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/plot__0",
)
],
"array": self.target(
f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/data.parquet",
),
),
}

@law.decorator.log
Expand Down
3 changes: 1 addition & 2 deletions columnflow/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
from collections.abc import KeysView, ValuesView # noqa
from types import ModuleType, GeneratorType, GenericAlias # noqa
from typing import ( # noqa
Any, Union, TypeVar, ClassVar, List, Tuple, Sequence, Set, Dict, Callable, Generator, TextIO,
Iterable,
Any, Union, TypeVar, ClassVar, Sequence, Callable, Generator, TextIO, Iterable,
)

from typing_extensions import Annotated, _AnnotatedAlias as AnnotatedType # noqa
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guide/structure.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ all parameters are explained? If so, at least to be mentioned here.

It should also be added that there are additional parameters specific for the tasks in columnflow,
required by the fact that columnflow's purpose is for HEP analysis. These are the ```--analysis```
and ```-config``` parameters, which defaults can be set in the law.cfg. These two parameters
and ```--config``` parameters, whose defaults can be set in the law.cfg. These two parameters
respectively define the config file for the different analyses to be used (where the different
analyses and their parameters should be defined) and the name of the config file for the specific
analysis to be used.
Expand Down

0 comments on commit c582628

Please sign in to comment.