diff --git a/src/graphnet/training/labels.py b/src/graphnet/training/labels.py index 11129f915..cd5e5f663 100644 --- a/src/graphnet/training/labels.py +++ b/src/graphnet/training/labels.py @@ -102,5 +102,6 @@ def __init__( def __call__(self, graph: Data) -> torch.tensor: """Compute label for `graph`.""" - label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1) - return label.type(torch.int) + is_numu = torch.abs(graph[self._pid_key]) == 14 + is_cc = graph[self._int_key] == 1 + return (is_numu & is_cc).type(torch.int)