diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 2923ce4d4..87d36fd29 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -8,7 +8,7 @@ import re -from columnflow.types import Sequence, Dict, List, Tuple +from columnflow.types import Sequence from columnflow.util import maybe_import ak = maybe_import("awkward") @@ -19,7 +19,7 @@ hep = maybe_import("mplhep") colors = maybe_import("matplotlib.colors") -# Define a CF custom color maps +# define a CF custom color maps cf_colors = { "cf_green_cmap": colors.ListedColormap([ "#212121", "#242723", "#262D25", "#283426", "#2A3A26", "#2C4227", "#2E4927", @@ -85,23 +85,26 @@ def plot_cm( cms_llabel: str = "private work", *args, **kwargs, -) -> Tuple[List[plt.Figure], np.ndarray]: - """ Generates the figure of the confusion matrix given the output of the nodes +) -> tuple[list[plt.Figure], np.ndarray]: + """ + Generates the figure of the confusion matrix given the output of the nodes and an array of true labels. The Cronfusion matrix can also be weighted. - :param events: dictionary with the true labels as keys and the model output of the events as values. - :param config_inst: used configuration for the plot. - :param category_inst: used category instance, for which the plot is created. - :param sample_weights: sample weights of the events. If an explicit array is not given, the weights are - calculated based on the number of events. - :param normalization: type of normalization of the confusion matrix. If not provided, the matrix is row normalized. + :param events: Dictionary with the true labels as keys and the model output of the events as values. + :param config_inst: The used config instance, for which the plot is created. + :param category_inst: The used category instance, for which the plot is created. + :param sample_weights: Sample weights applied to the confusion matrix the events. + If an explicit array is not given, the weights are calculated based on the number of events when set to *True*. + :param normalization: The type of normalization of the confusion matrix. + This parameter takes 'row', 'col' or '' (empty string) as argument. + If not provided, the matrix is row normalized. :param skip_uncertainties: If true, no uncertainty of the cells will be shown in the plot. - :param x_labels: labels for the x-axis. - :param y_labels: labels for the y-axis. + :param x_labels: The labels for the x-axis. If not provided, the labels will be 'out' + :param y_labels: The labels for the y-axis. If not provided, the dataset names are used. :param *args: Additional arguments to pass to the function. :param **kwargs: Additional keyword arguments to pass to the function. - :return: The resulting plot and the confusion matrix. + :return: Returns the resulting plot and the confusion matrix. :raises AssertionError: If both predictions and labels have mismatched shapes, or if *weights* is not *None* and its shape doesn't match *predictions*. @@ -135,13 +138,13 @@ def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray: vecNumber = np.vectorize(lambda n, count: sci.Number(n, float(n / np.sqrt(count) if count else 0))) result = vecNumber(result, counts) - # Normalize Matrix if needed + # normalize Matrix if needed if normalization: valid = {"row": 1, "column": 0} if normalization not in valid.keys(): raise ValueError( - f"\"{normalization}\" is not a valid argument for normalization. If given, normalization " - "should only take \"row\" or \"column\"", + f"'{normalization}' is not a valid argument for normalization. If given, normalization " + "should only take 'row' or 'column'", ) row_sums = result.sum(axis=valid.get(normalization)) @@ -180,15 +183,15 @@ def plot_confusion_matrix( from mpl_toolkits.axes_grid1 import make_axes_locatable def calculate_font_size(): - # Get cell width + # get cell width bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) width, height = fig.dpi * bbox.width, fig.dpi * bbox.height - # Size of each cell in pixels + # size of each cell in pixels cell_width = width / n_classes cell_height = height / n_processes - # Calculate the font size based on the cell size to ensure font is not too large + # calculate the font size based on the cell size to ensure font is not too large font_size = min(cell_width, cell_height) / 10 font_size = max(min(font_size, 18), 8) @@ -218,7 +221,7 @@ def fmt(v): plt.style.use(hep.style.CMS) fig, ax = plt.subplots(dpi=300) - # Some useful variables and functions + # some useful variables and functions n_processes = cm.shape[0] n_classes = cm.shape[1] cmap = cf_colors.get(colormap, cf_colors["cf_cmap"]) @@ -228,11 +231,11 @@ def fmt(v): font_label = 20 font_text = calculate_font_size() - # Get values and (if available) their uncertenties + # get values and (if available) their uncertenties values = cm.astype(np.float32) uncs = get_errors(cm) - # Remove Major ticks and edit minor ticks + # remove Major ticks and edit minor ticks minor_tick_length = max(int(120 / n_classes), 12) / 2 minor_tick_width = max(6 / n_classes, 0.6) xtick_marks = np.arange(n_classes) @@ -241,7 +244,7 @@ def fmt(v): # plot the data im = ax.imshow(values, interpolation="nearest", cmap=cmap) - # Plot settings + # plot settings thresh = values.max() / 2. ax.tick_params(axis="both", which="major", bottom=False, top=False, left=False, right=False) ax.tick_params( @@ -271,7 +274,7 @@ def fmt(v): colorbar.ax.tick_params(labelsize=font_ax - 5) im.set_clim(0, max(1, values.max())) - # Add Matrix Elemtns + # add Matrix Elemtns for i in range(values.shape[0]): for j in range(values.shape[1]): ax.text( @@ -319,7 +322,7 @@ def plot_roc( cms_llabel: str = "private work", *args, **kwargs, -) -> tuple[List[plt.Figure], dict]: +) -> tuple[list[plt.Figure], dict]: """ Generates the figure of the ROC curve given the output of the nodes and an array of true labels. The ROC curve can also be weighted. @@ -331,7 +334,7 @@ def plot_roc( calculated based on the number of events. :param n_thresholds: number of thresholds to use for the ROC curve. :param skip_discriminators: list of discriminators to skip. - :param evaluation_type: type of evaluation to use for the ROC curve. If not provided, the type is "OvR". + :param evaluation_type: type of evaluation to use for the ROC curve. If not provided, the type is 'OvR'. :param cms_rlabel: right label of the CMS label. :param cms_llabel: left label of the CMS label. :param *args: Additional arguments to pass to the function. @@ -341,7 +344,7 @@ def plot_roc( :raises ValueError: If both predictions and labels have mismatched shapes, or if *weights* is not *None* and its shape doesn't match *predictions*. - :raises ValueError: If *normalization* is not one of *None*, "row", "column". + :raises ValueError: If *normalization* is not one of *None*, 'row', 'column'. """ # defining some useful properties and output shapes thresholds = np.linspace(0, 1, n_thresholds) @@ -350,10 +353,12 @@ def plot_roc( figs = [] if evaluation_type not in ["OvO", "OvR"]: - raise ValueError("Illeagal Argument! Evaluation Type can only be choosen as \"OvO\" (One vs One) \ - or \"OvR\" (One vs Rest)") + raise ValueError( + "Illeagal Argument! Evaluation Type can only be choosen as 'OvO' (One vs One)" + "or 'OvR' (One vs Rest)" + ) - def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> Dict[str, Dict[str, np.ndarray]]: + def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> dict[str, dict[str, np.ndarray]]: """ Helper function to create histograms for the different discriminators and classes. """ @@ -370,12 +375,12 @@ def binary_roc_data( negativ_hist: np.ndarray, *args, **kwargs, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> tuple[np.ndarray, np.ndarray]: """ Compute binary Receiver operating characteristic (ROC) values. Used as a helper function for the multi-dimensional ROC curve """ - # Calculate the different rates + # calculate the different rates fn = np.cumsum(positiv_hist) tn = np.cumsum(negativ_hist) tp = fn[-1] - fn @@ -391,7 +396,7 @@ def roc_curve_data( histograms: dict, *args, **kwargs, - ) -> Dict[str, Dict[str, np.ndarray]]: + ) -> dict[str, dict[str, np.ndarray]]: """ Compute Receiver operating characteristic (ROC) values for a multi-dimensional output. """ @@ -403,7 +408,7 @@ def roc_curve_data( if disc in skip_discriminators: continue - # Choose the evaluation type + # choose the evaluation type if (evaluation_type == "OvO"): for pos_cls, pos_hist in histograms[disc].items(): for neg_cls, neg_hist in histograms[disc].items(): diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index b52dee209..78e113355 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -5,9 +5,7 @@ """ from __future__ import annotations -from collections import OrderedDict - -from collections import defaultdict +from collections import OrderedDict, defaultdict import law import luigi @@ -32,7 +30,6 @@ from columnflow.tasks.production import ProduceColumns from columnflow.util import dev_sandbox, safe_div, DotDict from columnflow.util import maybe_import -from columnflow.types import Dict, List ak = maybe_import("awkward") @@ -947,9 +944,9 @@ class PlotMLResultsBase( ) skip_processes = law.CSVParameter( - default=("",), + default=(), description="comma seperated list of process names to skip; these processes will not be included in the plots. " - "default: ('',)", + "default: ()", brace_expand=True, ) @@ -1005,14 +1002,14 @@ def workflow_requires(self: PlotMLResultsBase, only_super: bool = False): return reqs - def output(self: PlotMLResultsBase) -> Dict[str, List]: + def output(self: PlotMLResultsBase) -> dict[str, list]: b = self.branch_data return {"plots": [ self.target(name) for name in self.get_plot_names(f"plot__proc_{self.processes_repr}__cat_{b.category}") ]} - def prepare_inputs(self: PlotMLResultsBase) -> Dict[str, ak.Array]: + def prepare_inputs(self: PlotMLResultsBase) -> dict[str, ak.Array]: """ prepare the inputs for the plot function, based on the given configuration and category. @@ -1020,7 +1017,7 @@ def prepare_inputs(self: PlotMLResultsBase) -> Dict[str, ak.Array]: :raises ValueError: This error is raised if ``plot_sub_processes`` is used without providing the ``process_ids`` column in the data - :return: Dict[str, ak.Array]: A dictionary with the dataset names as keys and + :return: dict[str, ak.Array]: A dictionary with the dataset names as keys and the corresponding predictions as values. """ category_inst = self.config_inst.get_category(self.branch_data.category) @@ -1063,8 +1060,10 @@ def prepare_inputs(self: PlotMLResultsBase) -> Dict[str, ak.Array]: all_events[process_inst.name] = getattr(events, self.ml_model) else: if "process_ids" not in events.fields: - raise ValueError("No `process_ids` column stored in the events! " - f"Process selection for {dataset} cannot not be applied!") + raise ValueError( + "No `process_ids` column stored in the events! " + f"Process selection for {dataset} cannot not be applied!", + ) for sub_process in sub_process_insts[process_inst]: if sub_process.name in self.skip_processes: continue @@ -1112,15 +1111,16 @@ def prepare_plot_parameters(self: PlotMLResults): def output(self: PlotMLResults): b = self.branch_data - return {"plots": [ - self.target(name) - for name in self.get_plot_names( - f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/plot__0", - ) - ], + return { + "plots": [ + self.target(name) + for name in self.get_plot_names( + f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/plot__0", + ) + ], "array": self.target( f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/data.parquet", - ), + ), } @law.decorator.log diff --git a/columnflow/types.py b/columnflow/types.py index 9d4d96e17..2641c46f8 100644 --- a/columnflow/types.py +++ b/columnflow/types.py @@ -22,8 +22,7 @@ from collections.abc import KeysView, ValuesView # noqa from types import ModuleType, GeneratorType, GenericAlias # noqa from typing import ( # noqa - Any, Union, TypeVar, ClassVar, List, Tuple, Sequence, Set, Dict, Callable, Generator, TextIO, - Iterable, + Any, Union, TypeVar, ClassVar, Sequence, Callable, Generator, TextIO, Iterable, ) from typing_extensions import Annotated, _AnnotatedAlias as AnnotatedType # noqa diff --git a/docs/user_guide/structure.md b/docs/user_guide/structure.md index 63e872fe5..e44167104 100644 --- a/docs/user_guide/structure.md +++ b/docs/user_guide/structure.md @@ -190,7 +190,7 @@ all parameters are explained? If so, at least to be mentioned here. It should also be added that there are additional parameters specific for the tasks in columnflow, required by the fact that columnflow's purpose is for HEP analysis. These are the ```--analysis``` -and ```-config``` parameters, which defaults can be set in the law.cfg. These two parameters +and ```--config``` parameters, which defaults can be set in the law.cfg. These two parameters respectively define the config file for the different analyses to be used (where the different analyses and their parameters should be defined) and the name of the config file for the specific analysis to be used.