From b31aa42c056fb521531312dff72d61a89b80c8b7 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 20 May 2024 16:10:49 +0200 Subject: [PATCH 1/3] add missing `abs` to track label --- src/graphnet/training/labels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/training/labels.py b/src/graphnet/training/labels.py index 11129f915..44433b370 100644 --- a/src/graphnet/training/labels.py +++ b/src/graphnet/training/labels.py @@ -102,5 +102,5 @@ def __init__( def __call__(self, graph: Data) -> torch.tensor: """Compute label for `graph`.""" - label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1) + label = (torch.abs(graph[self._pid_key]) == 14) & (graph[self._int_key] == 1) return label.type(torch.int) From f51bcc6e4eca8df5558b8185e37b1b3880aff270 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 20 May 2024 16:22:36 +0200 Subject: [PATCH 2/3] check --- src/graphnet/training/labels.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/graphnet/training/labels.py b/src/graphnet/training/labels.py index 44433b370..190ddec7c 100644 --- a/src/graphnet/training/labels.py +++ b/src/graphnet/training/labels.py @@ -102,5 +102,7 @@ def __init__( def __call__(self, graph: Data) -> torch.tensor: """Compute label for `graph`.""" - label = (torch.abs(graph[self._pid_key]) == 14) & (graph[self._int_key] == 1) + is_numu = torch.abs(graph[self._pid_key]) == 14 + is_cc = graph[self._int_key] == 1 + label = is_numu & is_cc return label.type(torch.int) From 1ee2b53ee2af1d8b0116b0805b450c5023bd66b0 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Mon, 20 May 2024 16:28:10 +0200 Subject: [PATCH 3/3] shorten def --- src/graphnet/training/labels.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/graphnet/training/labels.py b/src/graphnet/training/labels.py index 190ddec7c..cd5e5f663 100644 --- a/src/graphnet/training/labels.py +++ b/src/graphnet/training/labels.py @@ -104,5 +104,4 @@ def __call__(self, graph: Data) -> torch.tensor: """Compute label for `graph`.""" is_numu = torch.abs(graph[self._pid_key]) == 14 is_cc = graph[self._int_key] == 1 - label = is_numu & is_cc - return label.type(torch.int) + return (is_numu & is_cc).type(torch.int)