From f51bcc6e4eca8df5558b8185e37b1b3880aff270 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 20 May 2024 16:22:36 +0200 Subject: [PATCH] check --- src/graphnet/training/labels.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/graphnet/training/labels.py b/src/graphnet/training/labels.py index 44433b370..190ddec7c 100644 --- a/src/graphnet/training/labels.py +++ b/src/graphnet/training/labels.py @@ -102,5 +102,7 @@ def __init__( def __call__(self, graph: Data) -> torch.tensor: """Compute label for `graph`.""" - label = (torch.abs(graph[self._pid_key]) == 14) & (graph[self._int_key] == 1) + is_numu = torch.abs(graph[self._pid_key]) == 14 + is_cc = graph[self._int_key] == 1 + label = is_numu & is_cc return label.type(torch.int)