More docstring fixes, simplifying graph calculations
pweigel committed Jan 6, 2025
1 parent 9673f50 commit aa70dc0
Showing 4 changed files with 16 additions and 39 deletions.
src/graphnet/models/components/layers.py (7 changes: 3 additions & 4 deletions)
@@ -21,8 +21,7 @@
 from torch_scatter import scatter
 
 from pytorch_lightning import LightningModule
-
-from graphnet.models.utils import get_log_deg
+from torch_geometric.utils import degree
 
 
 class DynEdgeConv(EdgeConv, LightningModule):
@@ -848,7 +847,7 @@ def __init__(
                 if norm_edges
                 else nn.Identity()
             )
-        else:  # TODO: Maybe just set this to nn.Identity. -PW
+        else:
             raise ValueError(
                 "GritTransformerLayer normalization layer must be 'LayerNorm' \
                     or 'BatchNorm1d'!"
@@ -881,7 +880,7 @@ def forward(self, data: Data) -> Data:
         """Forward pass."""
         x = data.x
         num_nodes = data.num_nodes
-        log_deg = get_log_deg(data)
+        log_deg = torch.log10(degree(data.edge_index[0]) + 1)
 
         x_attn_residual = x  # for first residual connection
         e_values_in = data.get("edge_attr", None)
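For reference, a minimal sketch of the inline log-degree computation used in the new forward pass; the toy edge_index below is illustrative and not from the repository:

import torch
from torch_geometric.utils import degree

# Toy directed graph with edges 0->1, 0->2, 2->1; edge_index[0] holds
# the source node of each edge, so degree() counts outgoing edges.
edge_index = torch.tensor([[0, 0, 2],
                           [1, 2, 1]])

deg = degree(edge_index[0])     # tensor([2., 0., 1.])
log_deg = torch.log10(deg + 1)  # tensor([0.4771, 0.0000, 0.3010])
print(log_deg)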
src/graphnet/models/gnn/grit.py (12 changes: 8 additions & 4 deletions)
@@ -73,18 +73,22 @@ def __init__(
             add_node_attr_as_self_loop: Adds node attr as an self-edge.
             dropout: Dropout probability.
             fill_value: Padding value.
-            norm: Normalization layer.
+            norm: Uninstantiated normalization layer.
+                Either `torch.nn.BatchNorm1d` or `torch.nn.LayerNorm`.
             attn_dropout: Attention dropout probability.
             edge_enhance: Applies learnable weight matrix with node-pair in
                 output node calculation for MHA.
             update_edges: Update edge values after GRIT layer.
             attn_clamp: Clamp absolute value of attention scores to a value.
-            activation: Activation function.
-            attn_activation: Attention activation function.
+            activation: Uninstantiated activation function.
+                E.g. `torch.nn.ReLU`
+            attn_activation: Uninstantiated attention activation function.
+                E.g. `torch.nn.ReLU`
             norm_edges: Apply normalization layer to edges.
             enable_edge_transform: Apply transformation to edges.
             pred_head_layers: Number of layers in the prediction head.
-            pred_head_activation: Prediction head activation function.
+            pred_head_activation: Uninstantiated prediction head activation
+                function. E.g. `torch.nn.ReLU`
             pred_head_pooling: Pooling function to use for the prediction head,
                 either "mean" (default) or "add".
             position_encoding: Method of position encoding.
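A minimal sketch of the "uninstantiated" convention these docstrings now spell out: the class itself is passed, and the model constructs it later with its own sizes. The hidden_dim value here is illustrative:

from torch import nn

# Pass the class itself (uninstantiated), not an instance.
norm = nn.BatchNorm1d  # not nn.BatchNorm1d(64)
activation = nn.ReLU   # not nn.ReLU()

# The consuming module can then instantiate with its own dimensions.
hidden_dim = 64  # illustrative value
norm_layer = norm(hidden_dim)
act_layer = activation()
print(norm_layer, act_layer)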
src/graphnet/models/graphs/graphs.py (9 changes: 4 additions & 5 deletions)
@@ -116,11 +116,10 @@ class KNNGraphRRWP(GraphDefinition):
     Identical to KNNGraph, but with five extra fields containing absolute and
     relative positional encoding using RRWP.
 
-    ``` abs_pe = graph["rrwp"] # RRWP absolute positional encoding values
-    rrwp_val = graph["rrwp_val"] # Non-zero values of the RRWP tensor
-    rrwp_index = graph["rrwp_index] # Corresponding row, col indices degree =
-    graph["deg"] # Degree of each node (num. of incoming edges) log_deg =
-    graph["log_deg"] # Equal to torch.log10(graph["deg"] + 1) ```
+    `abs_pe = graph["rrwp"] # RRWP absolute positional encoding values`
+    `rrwp_val = graph["rrwp_val"] # Non-zero values of the RRWP tensor`
+    `rrwp_index = graph["rrwp_index] # Corresponding row, col indices` `degree
+    = graph["deg"] # Degree of each node (num. of incoming edges)`
     """
 
     def __init__(
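As a usage sketch of the fields the revised docstring lists, with dummy tensors standing in for real RRWP output (the shapes here are illustrative, not what KNNGraphRRWP produces):

import torch
from torch_geometric.data import Data

# Dummy Data object; real values come from KNNGraphRRWP.
graph = Data(x=torch.randn(4, 3))
graph["rrwp"] = torch.randn(4, 8)
graph["rrwp_val"] = torch.randn(10, 8)
graph["rrwp_index"] = torch.randint(0, 4, (2, 10))
graph["deg"] = torch.randint(0, 4, (4,))

abs_pe = graph["rrwp"]            # RRWP absolute positional encoding
rrwp_val = graph["rrwp_val"]      # non-zero values of the RRWP tensor
rrwp_index = graph["rrwp_index"]  # corresponding (row, col) indices
deg = graph["deg"]                # degree of each node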
src/graphnet/models/utils.py (27 changes: 1 addition & 26 deletions)
@@ -7,7 +7,7 @@
 
 from torch_geometric.nn import knn_graph
 from torch_geometric.data import Batch, Data
-from torch_geometric.utils import homophily, degree, to_dense_adj
+from torch_geometric.utils import homophily, to_dense_adj
 from torch_geometric.utils.num_nodes import maybe_num_nodes
 
 from torch_scatter import scatter, scatter_add
@@ -195,7 +195,6 @@ def add_full_rrwp(
     )
 
     # Compute D^{-1} A:
-    deg = adj.sum(dim=1)
     deg_inv = 1.0 / adj.sum(dim=1)
     deg_inv[deg_inv == float("inf")] = 0
     adj = adj * deg_inv.view(-1, 1)
@@ -239,33 +238,9 @@ def add_full_rrwp(
     data[f"{attr_name_rel}_index"] = rel_pe_idx
     data[f"{attr_name_rel}_val"] = rel_pe_val
 
-    data.log_deg = torch.log(deg + 1)
-    data.deg = deg.type(torch.long)
-
     return data
 
 
-@torch.no_grad()
-def get_log_deg(data: Data) -> Tensor:
-    """Get log of the degree number of a graph.
-
-    Original code:
-    https://github.com/LiamMa/GRIT/blob/main/grit/layer/grit_layer.py
-    """
-    if "log_deg" in data:
-        log_deg = data.log_deg
-    elif "deg" in data:
-        deg = data.deg
-        log_deg = torch.log(deg + 1).unsqueeze(-1)
-    else:
-        deg = degree(
-            data.edge_index[1], num_nodes=data.num_nodes, dtype=data.x.dtype
-        )
-        log_deg = torch.log(deg + 1)
-    log_deg = log_deg.view(data.num_nodes, 1)
-    return log_deg
-
-
 def get_rw_landing_probs(
     ksteps: List,
     edge_index: Tensor,
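The D^{-1} A step shown above row-normalizes a dense adjacency matrix into random-walk transition probabilities. A self-contained sketch on a toy matrix (values illustrative), mirroring the zero-degree guard in the code above:

import torch

# Toy dense adjacency: row i lists the out-edges of node i.
adj = torch.tensor([[0.0, 1.0, 1.0],
                    [0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0]])

# D^{-1} A: divide each row by its degree, zeroing isolated nodes
# instead of propagating inf.
deg_inv = 1.0 / adj.sum(dim=1)
deg_inv[deg_inv == float("inf")] = 0
adj = adj * deg_inv.view(-1, 1)
print(adj)  # each non-empty row now sums to 1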
