From aa70dc049c472ffcfdd437d92fd30c04b7bc73ad Mon Sep 17 00:00:00 2001
From: Philip Weigel
Date: Mon, 6 Jan 2025 15:07:46 -0500
Subject: [PATCH] More docstring fixes, simplifying graph calculations

---
 src/graphnet/models/components/layers.py |  7 +++---
 src/graphnet/models/gnn/grit.py          | 12 +++++++----
 src/graphnet/models/graphs/graphs.py     |  9 ++++----
 src/graphnet/models/utils.py             | 27 +----------------------
 4 files changed, 16 insertions(+), 39 deletions(-)

diff --git a/src/graphnet/models/components/layers.py b/src/graphnet/models/components/layers.py
index 23b440ac3..992d1284d 100644
--- a/src/graphnet/models/components/layers.py
+++ b/src/graphnet/models/components/layers.py
@@ -21,8 +21,7 @@
 from torch_scatter import scatter
 from pytorch_lightning import LightningModule
 
-
-from graphnet.models.utils import get_log_deg
+from torch_geometric.utils import degree
 
 
 class DynEdgeConv(EdgeConv, LightningModule):
@@ -848,7 +847,7 @@ def __init__(
                 if norm_edges
                 else nn.Identity()
             )
-        else:  # TODO: Maybe just set this to nn.Identity. -PW
+        else:
             raise ValueError(
                 "GritTransformerLayer normalization layer must be 'LayerNorm' \
                     or 'BatchNorm1d'!"
@@ -881,7 +880,7 @@ def forward(self, data: Data) -> Data:
         """Forward pass."""
         x = data.x
         num_nodes = data.num_nodes
-        log_deg = get_log_deg(data)
+        log_deg = torch.log10(degree(data.edge_index[0]) + 1)
 
         x_attn_residual = x  # for first residual connection
         e_values_in = data.get("edge_attr", None)
diff --git a/src/graphnet/models/gnn/grit.py b/src/graphnet/models/gnn/grit.py
index c5edf3d0b..f42928a3f 100644
--- a/src/graphnet/models/gnn/grit.py
+++ b/src/graphnet/models/gnn/grit.py
@@ -73,18 +73,22 @@ def __init__(
             add_node_attr_as_self_loop: Adds node attr as an self-edge.
             dropout: Dropout probability.
             fill_value: Padding value.
-            norm: Normalization layer.
+            norm: Uninstantiated normalization layer.
+                Either `torch.nn.BatchNorm1d` or `torch.nn.LayerNorm`.
             attn_dropout: Attention dropout probability.
             edge_enhance: Applies learnable weight matrix with node-pair in
                 output node calculation for MHA.
             update_edges: Update edge values after GRIT layer.
             attn_clamp: Clamp absolute value of attention scores to a value.
-            activation: Activation function.
-            attn_activation: Attention activation function.
+            activation: Uninstantiated activation function.
+                E.g. `torch.nn.ReLU`
+            attn_activation: Uninstantiated attention activation function.
+                E.g. `torch.nn.ReLU`
             norm_edges: Apply normalization layer to edges.
             enable_edge_transform: Apply transformation to edges.
             pred_head_layers: Number of layers in the prediction head.
-            pred_head_activation: Prediction head activation function.
+            pred_head_activation: Uninstantiated prediction head activation
+                function. E.g. `torch.nn.ReLU`
             pred_head_pooling: Pooling function to use for the prediction
                 head, either "mean" (default) or "add".
             position_encoding: Method of position encoding.
diff --git a/src/graphnet/models/graphs/graphs.py b/src/graphnet/models/graphs/graphs.py
index 4303dc20a..8d7b318bb 100644
--- a/src/graphnet/models/graphs/graphs.py
+++ b/src/graphnet/models/graphs/graphs.py
@@ -116,11 +116,10 @@ class KNNGraphRRWP(GraphDefinition):
     Identical to KNNGraph, but with five extra fields containing absolute and
     relative positional encoding using RRWP.
 
-    ``` abs_pe = graph["rrwp"] # RRWP absolute positional encoding values
-    rrwp_val = graph["rrwp_val"] # Non-zero values of the RRWP tensor
-    rrwp_index = graph["rrwp_index] # Corresponding row, col indices degree =
-    graph["deg"] # Degree of each node (num. of incoming edges) log_deg =
-    graph["log_deg"] # Equal to torch.log10(graph["deg"] + 1) ```
+    `abs_pe = graph["rrwp"] # RRWP absolute positional encoding values`
+    `rrwp_val = graph["rrwp_val"] # Non-zero values of the RRWP tensor`
+    `rrwp_index = graph["rrwp_index"] # Corresponding row, col indices`
+    `degree = graph["deg"] # Degree of each node (num. of incoming edges)`
     """
 
     def __init__(
diff --git a/src/graphnet/models/utils.py b/src/graphnet/models/utils.py
index e64f520de..06cf41f8c 100644
--- a/src/graphnet/models/utils.py
+++ b/src/graphnet/models/utils.py
@@ -7,7 +7,7 @@
 
 from torch_geometric.nn import knn_graph
 from torch_geometric.data import Batch, Data
-from torch_geometric.utils import homophily, degree, to_dense_adj
+from torch_geometric.utils import homophily, to_dense_adj
 from torch_geometric.utils.num_nodes import maybe_num_nodes
 from torch_scatter import scatter, scatter_add
 
@@ -195,7 +195,6 @@ def add_full_rrwp(
     )
 
     # Compute D^{-1} A:
-    deg = adj.sum(dim=1)
     deg_inv = 1.0 / adj.sum(dim=1)
     deg_inv[deg_inv == float("inf")] = 0
     adj = adj * deg_inv.view(-1, 1)
@@ -239,33 +238,9 @@ def add_full_rrwp(
     data[f"{attr_name_rel}_index"] = rel_pe_idx
     data[f"{attr_name_rel}_val"] = rel_pe_val
 
-    data.log_deg = torch.log(deg + 1)
-    data.deg = deg.type(torch.long)
-
     return data
 
 
-@torch.no_grad()
-def get_log_deg(data: Data) -> Tensor:
-    """Get log of the degree number of a graph.
-
-    Original code:
-    https://github.com/LiamMa/GRIT/blob/main/grit/layer/grit_layer.py
-    """
-    if "log_deg" in data:
-        log_deg = data.log_deg
-    elif "deg" in data:
-        deg = data.deg
-        log_deg = torch.log(deg + 1).unsqueeze(-1)
-    else:
-        deg = degree(
-            data.edge_index[1], num_nodes=data.num_nodes, dtype=data.x.dtype
-        )
-        log_deg = torch.log(deg + 1)
-        log_deg = log_deg.view(data.num_nodes, 1)
-    return log_deg
-
-
 def get_rw_landing_probs(
     ksteps: List,
     edge_index: Tensor,
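
For reference, below is a minimal sketch of the log-degree calculation this patch inlines in `GritTransformerLayer.forward`, alongside the fallback branch of the removed `get_log_deg` helper. The toy graph, variable names, and dtype choice are illustrative assumptions rather than code from the repository; the sketch only presumes `torch` and `torch_geometric` are installed.

```python
import torch
from torch_geometric.data import Data
from torch_geometric.utils import degree

# Toy directed graph with 4 nodes; edge_index rows are (source, target).
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
data = Data(edge_index=edge_index, num_nodes=4)

# Expression inlined by this patch in GritTransformerLayer.forward:
# base-10 log of (source-node degree + 1).
log_deg = torch.log10(degree(data.edge_index[0]) + 1)

# Fallback branch of the removed get_log_deg helper, for comparison:
# target-node degree with an explicit node count, natural log, and a
# (num_nodes, 1) column shape.
deg = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.float)
log_deg_old = torch.log(deg + 1).view(data.num_nodes, 1)

print(log_deg.shape, log_deg_old.shape)  # torch.Size([4]) torch.Size([4, 1])
```

As the diff shows, the inlined expression reads the source row `edge_index[0]`, uses a base-10 log, and infers the node count from the edge index, whereas the removed helper read the target row `edge_index[1]`, used the natural log, and passed `num_nodes` explicitly.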