From a560e84dfdac6ebcab709802b6f5ba423cc11ca0 Mon Sep 17 00:00:00 2001
From: Anas Haddad
Date: Fri, 8 Dec 2023 13:11:44 +0100
Subject: [PATCH] implemented comments

---
 columnflow/plotting/plot_ml_evaluation.py | 102 +++++++++++++++-------
 tests/__init__.py                         |   2 +
 tests/test_plotting.py                    |  11 +--
 3 files changed, 76 insertions(+), 39 deletions(-)

diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py
index e55940fe4..ff243e3c7 100644
--- a/columnflow/plotting/plot_ml_evaluation.py
+++ b/columnflow/plotting/plot_ml_evaluation.py
@@ -72,6 +72,39 @@ def create_sample_weights(
     return dict(zip(true_labels, sample_weights))


+def get_cms_config(ax: plt.Axes, llabel: str, rlabel: str) -> dict:
+    """
+    Helper function to get the CMS label configuration.
+
+    :param ax: The axis to plot the CMS label on.
+    :param llabel: The left label of the CMS label.
+    :param rlabel: The right label of the CMS label.
+    :return: A dictionary with the CMS label configuration.
+    """
+    label_options = {
+        "wip": "Work in progress",
+        "pre": "Preliminary",
+        "pw": "Private work",
+        "sim": "Simulation",
+        "simwip": "Simulation work in progress",
+        "simpre": "Simulation preliminary",
+        "simpw": "Simulation private work",
+        "od": "OpenData",
+        "odwip": "OpenData work in progress",
+        "odpw": "OpenData private work",
+        "public": "",
+    }
+    cms_label_kwargs = {
+        "ax": ax,
+        "llabel": label_options.get(llabel, llabel),
+        "rlabel": rlabel,
+        "fontsize": 22,
+        "data": False,
+    }
+
+    return cms_label_kwargs
+
+
 def plot_cm(
     events: dict,
     config_inst: od.Config,
@@ -82,7 +115,7 @@
     x_labels: list[str] | None = None,
     y_labels: list[str] | None = None,
     cms_rlabel: str = "",
-    cms_llabel: str = "private work",
+    cms_llabel: str = "pw",
     *args,
     **kwargs,
 ) -> tuple[list[plt.Figure], np.ndarray]:
@@ -288,13 +321,14 @@ def fmt(v):
         )

         # final touches
-        hep.cms.label(ax=ax, llabel={"pw": "private work"}.get(cms_llabel, cms_llabel), rlabel=cms_rlabel)
+        if cms_llabel != "skip":
+            cms_label_kwargs = get_cms_config(ax=ax, llabel=cms_llabel, rlabel=cms_rlabel)
+            hep.cms.label(**cms_label_kwargs)
         plt.tight_layout()

         return fig

     cm = get_conf_matrix(sample_weights, *args, **kwargs)
-    print("Confusion matrix calculated!")

     fig = plot_confusion_matrix(
         cm,
@@ -316,7 +350,7 @@ def plot_roc(
     category_inst: od.Category,
     sample_weights: Sequence | bool = False,
     n_thresholds: int = 200 + 1,
-    skip_discriminators: list[str] = [],
+    skip_discriminators: list[str] | None = None,
     evaluation_type: str = "OvR",
     cms_rlabel: str = "",
     cms_llabel: str = "private work",
@@ -327,16 +361,16 @@ def plot_roc(
     Generates the figure of the ROC curve given the output of the nodes
     and an array of true labels. The ROC curve can also be weighted.

-    :param events: dictionary with the true labels as keys and the model output of the events as values.
-    :param config_inst: used configuration for the plot.
-    :param category_inst: used category instance, for which the plot is created.
-    :param sample_weights: sample weights of the events. If an explicit array is not given, the weights are
+    :param events: Dictionary with the true labels as keys and the model output of the events as values.
+    :param config_inst: Used configuration for the plot.
+    :param category_inst: Used category instance, for which the plot is created.
+    :param sample_weights: Sample weights of the events. If an explicit array is not given, the weights are
         calculated based on the number of events.
-    :param n_thresholds: number of thresholds to use for the ROC curve.
-    :param skip_discriminators: list of discriminators to skip.
-    :param evaluation_type: type of evaluation to use for the ROC curve. If not provided, the type is 'OvR'.
-    :param cms_rlabel: right label of the CMS label.
-    :param cms_llabel: left label of the CMS label.
+    :param n_thresholds: Number of thresholds to use for the ROC curve.
+    :param skip_discriminators: List of discriminators to skip.
+    :param evaluation_type: Type of evaluation to use for the ROC curve. If not provided, the type is 'OvR'.
+    :param cms_rlabel: Right label of the CMS label.
+    :param cms_llabel: Left label of the CMS label.
     :param *args: Additional arguments to pass to the function.
     :param **kwargs: Additional keyword arguments to pass to the function.

@@ -351,6 +385,8 @@ def plot_roc(
     weights = create_sample_weights(sample_weights, events, list(events.keys()))
     discriminators = list(events.values())[0].fields
     figs = []
+    if skip_discriminators is None:
+        skip_discriminators = []

     if evaluation_type not in ["OvO", "OvR"]:
         raise ValueError(
@@ -366,8 +402,10 @@ def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> di
         for disc in discriminators:
             hists[disc] = {}
             for cls, predictions in events.items():
-                hists[disc][cls] = (sample_weights[cls] *
-                    ak.to_numpy(np.histogram(predictions[disc], bins=thresholds)[0]))
+                hists[disc][cls] = (
+                    sample_weights[cls] *
+                    ak.to_numpy(np.histogram(predictions[disc], bins=thresholds)[0])
+                )
         return hists

     def binary_roc_data(
@@ -403,16 +441,17 @@ def roc_curve_data(

         result = {}
         for disc in discriminators:
-            tmp = {}
-
             if disc in skip_discriminators:
                 continue

+            tmp = {}
+
             # choose the evaluation type
             if (evaluation_type == "OvO"):
                 for pos_cls, pos_hist in histograms[disc].items():
                     for neg_cls, neg_hist in histograms[disc].items():
-
+                        if (pos_cls == neg_cls):
+                            continue
                         fpr, tpr = binary_roc_data(
                             positiv_hist=pos_hist,
                             negativ_hist=neg_hist,
@@ -423,11 +462,13 @@ def roc_curve_data(

             elif (evaluation_type == "OvR"):
                 for pos_cls, pos_hist in histograms[disc].items():
-                    neg_hist = np.zeros_like(pos_hist)
-                    for neg_cls, neg_pred in histograms[disc].items():
-                        if (pos_cls == neg_cls):
-                            continue
-                        neg_hist += neg_pred
+                    neg_hist = sum(
+                        (
+                            neg_pred
+                            for neg_cls, neg_pred in histograms[disc].items()
+                            if neg_cls != pos_cls
+                        ),
+                    )

                     fpr, tpr = binary_roc_data(
                         positiv_hist=pos_hist,
@@ -445,7 +486,7 @@ def plot_roc_curve(
         roc_data: dict,
         title: str,
         cms_rlabel: str = "",
-        cms_llabel: str = "private work",
+        cms_llabel: str = "pw",
         *args,
         **kwargs,
     ) -> plt.figure:
@@ -458,10 +499,9 @@ def auc_score(fpr: list, tpr: list, *args) -> np.float64:
            """
            sign = 1
            if np.any(np.diff(fpr) < 0):
-               if np.all(np.diff(fpr) <= 0):
-                   sign = -1
-               else:
-                   raise ValueError("x is neither increasing nor decreasing : {}.".format(fpr))
+               if np.any(np.diff(fpr) > 0):
+                   raise ValueError(f"x is neither increasing nor decreasing : {fpr}.")
+               sign = -1

            return sign * np.trapz(tpr, fpr)

@@ -486,17 +526,17 @@ def auc_score(fpr: list, tpr: list, *args) -> np.float64:
         # setting titles and legend
         ax.legend(loc="lower right", fontsize=15)
         ax.set_title(title, wrap=True, pad=50, fontsize=30)
-        hep.cms.label(ax=ax, llabel={"pw": "private work"}.get(cms_llabel, cms_llabel), rlabel=cms_rlabel)
+        if cms_llabel != "skip":
+            cms_label_kwargs = get_cms_config(ax=ax, llabel=cms_llabel, rlabel=cms_rlabel)
+            hep.cms.label(**cms_label_kwargs)
         plt.tight_layout()

         return fig

     # create histograms and calculate FPR and TPR
     histograms = create_histograms(events, weights, *args, **kwargs)
-    print("histograms created!")

     results = roc_curve_data(evaluation_type, histograms, *args, **kwargs)
-    print("ROC data calculated!")

     # plotting
     for disc, roc in results.items():
diff --git a/tests/__init__.py b/tests/__init__.py
index 1573fa31e..92e51bd50 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -17,3 +17,5 @@
 # import all tests
 from .test_util import *
 from .test_columnar_util import *
+from .test_task_parameters import *
+from .test_plotting import *
diff --git a/tests/test_plotting.py b/tests/test_plotting.py
index 80e7d8da5..4e16c59f7 100644
--- a/tests/test_plotting.py
+++ b/tests/test_plotting.py
@@ -202,7 +202,7 @@ def setUp(self):
                 "out2": [0.8, 0.8, 0.6, 0.2],
             }),
         }
-        self.N_discriminators = 2
+        self.n_discriminators = 2
         self.config_inst = MagicMock()
         self.category_inst = MagicMock()
         # The following results are calculated by hand
@@ -214,7 +214,6 @@ def setUp(self):
             "out2": {
                 "fpr": [0.75, 0.75, 0.5, 0],
                 "tpr": [0.5, 0.25, 0, 0],
-
             },
         }
         self.text_trap = io.StringIO()
@@ -230,8 +229,8 @@ def test_plot_roc_returns_correct_number_of_figures(self):
         figs_ovr, _ = plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="OvR")
         figs_ovo, _ = plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="OvO")

-        self.assertEqual(len(figs_ovr), self.N_discriminators * len(self.events))
-        self.assertEqual(len(figs_ovo), self.N_discriminators * len(self.events) * (len(self.events)))
+        self.assertEqual(len(figs_ovr), self.n_discriminators * len(self.events))
+        self.assertEqual(len(figs_ovo), self.n_discriminators * len(self.events) * (len(self.events) - 1))

 def test_plot_roc_returns_correct_results(self):
     with redirect_stdout(self.text_trap):
@@ -254,7 +253,3 @@ def test_plot_roc_returns_correct_results(self):
 def test_plot_roc_raises_value_error_for_invalid_evaluation_type(self):
     with self.assertRaises(ValueError), redirect_stdout(self.text_trap):
         plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="InvalidType")
-
-
-if __name__ == "__main__":
-    unittest.main()
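
Reviewer notes (not part of the patch): the CMS label handling is now centralized in get_cms_config, with the string "skip" acting as a sentinel that suppresses the stamp entirely; the mutable default skip_discriminators=[] is replaced by None plus an in-function fallback, so repeated calls no longer share one list; and the new pos_cls == neg_cls guard in the OvO branch means n classes yield n * (n - 1) ordered class pairs per discriminator, which is why the OvO test expectation gains the "- 1". The minimal sketch below shows how the new helper is meant to be used once the patch is applied; the label values are made-up examples, and it assumes matplotlib and mplhep are installed.

    # Hypothetical usage sketch of the reworked CMS label handling (assumes
    # this patch is merged into columnflow): short keys resolve via
    # label_options, unknown strings pass through verbatim, and "skip"
    # suppresses the label entirely.
    import matplotlib.pyplot as plt
    import mplhep as hep

    from columnflow.plotting.plot_ml_evaluation import get_cms_config

    fig, ax = plt.subplots()
    cms_llabel, cms_rlabel = "simpw", "13.6 TeV"  # example values only

    if cms_llabel != "skip":
        # get_cms_config returns the kwargs dict consumed by hep.cms.label,
        # here {"ax": ax, "llabel": "Simulation private work",
        # "rlabel": "13.6 TeV", "fontsize": 22, "data": False},
        # mirroring both call sites in the patch.
        hep.cms.label(**get_cms_config(ax=ax, llabel=cms_llabel, rlabel=cms_rlabel))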