Commit

implemented comments
haddadanas committed Dec 8, 2023
1 parent b46f3c8 commit a560e84
Showing 3 changed files with 76 additions and 39 deletions.
102 changes: 71 additions & 31 deletions columnflow/plotting/plot_ml_evaluation.py
@@ -72,6 +72,39 @@ def create_sample_weights(
return dict(zip(true_labels, sample_weights))


def get_cms_config(ax: plt.Axes, llabel: str, rlabel: str) -> dict:
"""
Helper function to get the CMS label configuration.
:param ax: The axis to plot the CMS label on.
:param llabel: The left label of the CMS label.
:param rlabel: The right label of the CMS label.
:return: A dictionary with the CMS label configuration.
"""
label_options = {
"wip": "Work in progress",
"pre": "Preliminary",
"pw": "Private work",
"sim": "Simulation",
"simwip": "Simulation work in progress",
"simpre": "Simulation preliminary",
"simpw": "Simulation private work",
"od": "OpenData",
"odwip": "OpenData work in progress",
"odpw": "OpenData private work",
"public": "",
}
cms_label_kwargs = {
"ax": ax,
"llabel": label_options.get(llabel, llabel),
"rlabel": rlabel,
"fontsize": 22,
"data": False,
}

return cms_label_kwargs
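
The new helper centralizes the mplhep label handling that was previously inlined at each call site. As a minimal usage sketch (assuming mplhep is installed and the helper is imported from this module; the "13.6 TeV" right label is a made-up example):

    import matplotlib.pyplot as plt
    import mplhep as hep

    from columnflow.plotting.plot_ml_evaluation import get_cms_config

    fig, ax = plt.subplots()
    # "simpw" resolves to "Simulation private work"; unknown keys pass through unchanged
    hep.cms.label(**get_cms_config(ax=ax, llabel="simpw", rlabel="13.6 TeV"))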


def plot_cm(
events: dict,
config_inst: od.Config,
@@ -82,7 +115,7 @@ def plot_cm(
x_labels: list[str] | None = None,
y_labels: list[str] | None = None,
cms_rlabel: str = "",
cms_llabel: str = "private work",
cms_llabel: str = "pw",
*args,
**kwargs,
) -> tuple[list[plt.Figure], np.ndarray]:
@@ -288,13 +321,14 @@ def fmt(v):
)

# final touches
hep.cms.label(ax=ax, llabel={"pw": "private work"}.get(cms_llabel, cms_llabel), rlabel=cms_rlabel)
if cms_llabel != "skip":
cms_label_kwargs = get_cms_config(ax=ax, llabel=cms_llabel, rlabel=cms_rlabel)
hep.cms.label(**cms_label_kwargs)
plt.tight_layout()

return fig

cm = get_conf_matrix(sample_weights, *args, **kwargs)
print("Confusion matrix calculated!")

fig = plot_confusion_matrix(
cm,
@@ -316,7 +350,7 @@ def plot_roc(
category_inst: od.Category,
sample_weights: Sequence | bool = False,
n_thresholds: int = 200 + 1,
skip_discriminators: list[str] = [],
skip_discriminators: list[str] | None = None,
evaluation_type: str = "OvR",
cms_rlabel: str = "",
cms_llabel: str = "private work",
@@ -327,16 +361,16 @@
Generates the figure of the ROC curve given the output of the nodes
and an array of true labels. The ROC curve can also be weighted.
:param events: dictionary with the true labels as keys and the model output of the events as values.
:param config_inst: used configuration for the plot.
:param category_inst: used category instance, for which the plot is created.
:param sample_weights: sample weights of the events. If an explicit array is not given, the weights are
:param events: Dictionary with the true labels as keys and the model output of the events as values.
:param config_inst: Used configuration for the plot.
:param category_inst: Used category instance, for which the plot is created.
:param sample_weights: Sample weights of the events. If an explicit array is not given, the weights are
calculated based on the number of events.
:param n_thresholds: number of thresholds to use for the ROC curve.
:param skip_discriminators: list of discriminators to skip.
:param evaluation_type: type of evaluation to use for the ROC curve. If not provided, the type is 'OvR'.
:param cms_rlabel: right label of the CMS label.
:param cms_llabel: left label of the CMS label.
:param n_thresholds: Number of thresholds to use for the ROC curve.
:param skip_discriminators: List of discriminators to skip.
:param evaluation_type: Type of evaluation to use for the ROC curve. If not provided, the type is 'OvR'.
:param cms_rlabel: Right label of the CMS label.
:param cms_llabel: Left label of the CMS label.
:param *args: Additional arguments to pass to the function.
:param **kwargs: Additional keyword arguments to pass to the function.
@@ -351,6 +385,8 @@ def plot_roc(
weights = create_sample_weights(sample_weights, events, list(events.keys()))
discriminators = list(events.values())[0].fields
figs = []
if skip_discriminators is None:
skip_discriminators = []
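
Replacing the mutable default skip_discriminators: list[str] = [] with None and filling it in at call time avoids a classic Python pitfall: a default list is created once, when the function is defined, and then shared across all calls. A toy illustration of the pitfall (hypothetical function, not from this module):

    def buggy(items: list = []):  # the default list is created once, at definition time
        items.append(1)
        return items

    print(buggy())  # [1]
    print(buggy())  # [1, 1] -- the same list object is reused across calls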

if evaluation_type not in ["OvO", "OvR"]:
raise ValueError(
@@ -366,8 +402,10 @@ def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> dict:
for disc in discriminators:
hists[disc] = {}
for cls, predictions in events.items():
hists[disc][cls] = (sample_weights[cls] *
ak.to_numpy(np.histogram(predictions[disc], bins=thresholds)[0]))
hists[disc][cls] = (
sample_weights[cls] *
ak.to_numpy(np.histogram(predictions[disc], bins=thresholds)[0])
)
return hists
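
create_histograms bins each class's discriminator scores into per-threshold counts, weighted per class. binary_roc_data is collapsed in this view; one standard way to turn such histograms into ROC points is via reversed cumulative sums, sketched below under that assumption (not necessarily the exact implementation):

    import numpy as np

    def roc_from_hists(pos_hist: np.ndarray, neg_hist: np.ndarray):
        # counts at or above each threshold, scanning from the highest bin down
        tp = np.cumsum(pos_hist[::-1])[::-1]
        fp = np.cumsum(neg_hist[::-1])[::-1]
        return fp / fp[0], tp / tp[0]  # (fpr, tpr), normalized to the totals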

def binary_roc_data(
@@ -403,16 +441,17 @@ def roc_curve_data(
result = {}

for disc in discriminators:
tmp = {}

if disc in skip_discriminators:
continue

tmp = {}

# choose the evaluation type
if (evaluation_type == "OvO"):
for pos_cls, pos_hist in histograms[disc].items():
for neg_cls, neg_hist in histograms[disc].items():

if (pos_cls == neg_cls):
continue
fpr, tpr = binary_roc_data(
positiv_hist=pos_hist,
negativ_hist=neg_hist,
@@ -423,11 +462,13 @@

elif (evaluation_type == "OvR"):
for pos_cls, pos_hist in histograms[disc].items():
neg_hist = np.zeros_like(pos_hist)
for neg_cls, neg_pred in histograms[disc].items():
if (pos_cls == neg_cls):
continue
neg_hist += neg_pred
neg_hist = sum(
(
neg_pred
for neg_cls, neg_pred in histograms[disc].items()
if neg_cls != pos_cls
),
)

fpr, tpr = binary_roc_data(
positiv_hist=pos_hist,
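
The OvR rest-histogram is now built with sum over a generator instead of an explicit accumulation loop; both forms are equivalent, since sum starts from 0 and NumPy broadcasts the additions. A toy check with made-up counts:

    import numpy as np

    hists = {
        "sig": np.array([4, 1, 0]),
        "bkg1": np.array([1, 2, 3]),
        "bkg2": np.array([0, 1, 2]),
    }
    neg_hist = sum(h for cls, h in hists.items() if cls != "sig")
    assert (neg_hist == np.array([1, 3, 5])).all()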
@@ -445,7 +486,7 @@ def plot_roc_curve(
roc_data: dict,
title: str,
cms_rlabel: str = "",
cms_llabel: str = "private work",
cms_llabel: str = "pw",
*args,
**kwargs,
) -> plt.Figure:
@@ -458,10 +499,9 @@ def auc_score(fpr: list, tpr: list, *args) -> np.float64:
"""
sign = 1
if np.any(np.diff(fpr) < 0):
if np.all(np.diff(fpr) <= 0):
sign = -1
else:
raise ValueError("x is neither increasing nor decreasing : {}.".format(fpr))
if np.any(np.diff(fpr) > 0):
raise ValueError(f"x is neither increasing nor decreasing : {fpr}.")
sign = -1

return sign * np.trapz(tpr, fpr)
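
np.trapz integrates along the supplied x values, so an fpr array that runs from 1 down to 0 yields a negative area; the sign flip restores the conventional positive AUC. For example:

    import numpy as np

    fpr = np.array([1.0, 0.6, 0.2, 0.0])  # monotonically decreasing
    tpr = np.array([1.0, 0.9, 0.5, 0.0])
    area = np.trapz(tpr, fpr)  # negative, because x decreases
    assert area < 0
    assert np.isclose(-area, np.trapz(tpr[::-1], fpr[::-1]))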

@@ -486,17 +526,17 @@ def auc_score(fpr: list, tpr: list, *args) -> np.float64:
# setting titles and legend
ax.legend(loc="lower right", fontsize=15)
ax.set_title(title, wrap=True, pad=50, fontsize=30)
hep.cms.label(ax=ax, llabel={"pw": "private work"}.get(cms_llabel, cms_llabel), rlabel=cms_rlabel)
if cms_llabel != "skip":
cms_label_kwargs = get_cms_config(ax=ax, llabel=cms_llabel, rlabel=cms_rlabel)
hep.cms.label(**cms_label_kwargs)
plt.tight_layout()

return fig

# create histograms and calculate FPR and TPR
histograms = create_histograms(events, weights, *args, **kwargs)
print("histograms created!")

results = roc_curve_data(evaluation_type, histograms, *args, **kwargs)
print("ROC data calculated!")

# plotting
for disc, roc in results.items():
2 changes: 2 additions & 0 deletions tests/__init__.py
@@ -17,3 +17,5 @@
# import all tests
from .test_util import *
from .test_columnar_util import *
from .test_task_parameters import *
from .test_plotting import *
11 changes: 3 additions & 8 deletions tests/test_plotting.py
@@ -202,7 +202,7 @@ def setUp(self):
"out2": [0.8, 0.8, 0.6, 0.2],
}),
}
self.N_discriminators = 2
self.n_discriminators = 2
self.config_inst = MagicMock()
self.category_inst = MagicMock()
# The following results are calculated by hand
@@ -214,7 +214,6 @@
"out2": {
"fpr": [0.75, 0.75, 0.5, 0],
"tpr": [0.5, 0.25, 0, 0],

},
}
self.text_trap = io.StringIO()
@@ -230,8 +229,8 @@ def test_plot_roc_returns_correct_number_of_figures(self):
figs_ovr, _ = plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="OvR")
figs_ovo, _ = plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="OvO")

self.assertEqual(len(figs_ovr), self.N_discriminators * len(self.events))
self.assertEqual(len(figs_ovo), self.N_discriminators * len(self.events) * (len(self.events)))
self.assertEqual(len(figs_ovr), self.n_discriminators * len(self.events))
self.assertEqual(len(figs_ovo), self.n_discriminators * len(self.events) * (len(self.events) - 1))
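
The corrected OvO expectation reflects that one figure is produced per ordered (positive, negative) class pair, excluding pos == neg. Assuming the two classes and two discriminators defined in setUp:

    n_classes, n_discriminators = 2, 2  # assumed from the setUp above
    assert n_discriminators * n_classes == 4                    # OvR: one figure per (discriminator, class)
    assert n_discriminators * n_classes * (n_classes - 1) == 4  # OvO: one per ordered pair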

def test_plot_roc_returns_correct_results(self):
with redirect_stdout(self.text_trap):
@@ -254,7 +253,3 @@ def test_plot_roc_returns_correct_results(self):
def test_plot_roc_raises_value_error_for_invalid_evaluation_type(self):
with self.assertRaises(ValueError), redirect_stdout(self.text_trap):
plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="InvalidType")


if __name__ == "__main__":
unittest.main()
