Merge pull request #676 from ChenLi2049/iseecube
ISeeCube implementation
chenlinear authored Apr 18, 2024
2 parents 2248da4 + 007a0d7 commit 23581d8
Showing 5 changed files with 129 additions and 9 deletions.
1 change: 1 addition & 0 deletions setup.py
@@ -27,6 +27,7 @@
"tqdm>=4.64",
"wandb>=0.12",
"polars >=0.19",
"torchscale==0.2.0",
"h5py>= 3.7.0",
]

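As a quick sanity check (not part of the commit), the pinned torchscale release can be confirmed to expose the two symbols the new module relies on; the version lookup via importlib.metadata is just one way to verify the pin.

# Hypothetical sanity check: confirm the pinned torchscale release is installed
# and exposes the two classes used by iseecube.py below.
from importlib.metadata import version

from torchscale.architecture.config import EncoderConfig
from torchscale.architecture.encoder import Encoder

print(version("torchscale"))  # expected: 0.2.0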
23 changes: 17 additions & 6 deletions src/graphnet/models/components/embedding.py
@@ -3,6 +3,8 @@
import torch.nn as nn
from torch.functional import Tensor

from typing import Optional

from pytorch_lightning import LightningModule


@@ -61,6 +63,7 @@ class FourierEncoder(LightningModule):
def __init__(
self,
seq_length: int = 128,
mlp_dim: Optional[int] = None,
output_dim: int = 384,
scaled: bool = False,
n_features: int = 6,
@@ -70,7 +73,11 @@ def __init__(
Args:
seq_length: Dimensionality of the base sinusoidal positional
embeddings.
output_dim: Output dimensionality of the final projection.
mlp_dim (Optional): Size of the hidden latent space of the MLP. If not
given, `mlp_dim` is set automatically as a multiple of
`seq_length` (consistent with the 2nd-place solution),
depending on `n_features`.
output_dim: Dimension of the output (i.e. the number of columns).
scaled: Whether or not to scale the embeddings.
n_features: The number of features in the input data.
"""
@@ -90,11 +97,14 @@ def __init__(
else:
hidden_dim = int((n_features + 0.5) * seq_length)

self.projection = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
if mlp_dim is None:
mlp_dim = hidden_dim

self.mlp = nn.Sequential(
nn.Linear(hidden_dim, mlp_dim),
nn.LayerNorm(mlp_dim),
nn.GELU(),
nn.Linear(hidden_dim, output_dim),
nn.Linear(mlp_dim, output_dim),
)

self.n_features = n_features
@@ -121,7 +131,8 @@ def forward(
) # Length

x = torch.cat(embeddings, -1)
x = self.projection(x)
x = self.mlp(x)

return x


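For context (not part of the diff), a minimal sketch of how the new `mlp_dim` argument behaves, based on the constructor above; the keyword values are illustrative only.

from graphnet.models.components.embedding import FourierEncoder

# mlp_dim=None (the default) falls back to the automatically derived hidden_dim,
# so existing configurations reproduce the old `projection` head unchanged.
enc_auto = FourierEncoder(seq_length=128, output_dim=384, n_features=6)

# An explicit mlp_dim widens the hidden layer of the MLP, as ISeeCube does below.
enc_wide = FourierEncoder(seq_length=196, mlp_dim=1536, output_dim=384, n_features=6)

print(enc_auto.mlp)  # Linear(hidden_dim, hidden_dim) -> LayerNorm -> GELU -> Linear(hidden_dim, 384)
print(enc_wide.mlp)  # Linear(hidden_dim, 1536) -> LayerNorm -> GELU -> Linear(1536, 384)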
9 changes: 6 additions & 3 deletions src/graphnet/models/gnn/icemix.py
@@ -34,6 +34,7 @@ class DeepIce(GNN):
def __init__(
self,
hidden_dim: int = 384,
mlp_ratio: int = 4,
seq_length: int = 192,
depth: int = 12,
head_size: int = 32,
@@ -48,6 +49,7 @@ def __init__(
Args:
hidden_dim: The latent feature dimension.
mlp_ratio: MLP expansion ratio of the FourierEncoder and Transformer.
seq_length: The base feature dimension.
depth: The depth of the transformer.
head_size: The size of the attention heads.
@@ -65,8 +67,9 @@ def __init__(
super().__init__(seq_length, hidden_dim)
fourier_out_dim = hidden_dim // 2 if include_dynedge else hidden_dim
self.fourier_ext = FourierEncoder(
seq_length,
fourier_out_dim,
seq_length=seq_length,
mlp_dim=None,
output_dim=fourier_out_dim,
scaled=scaled_emb,
n_features=n_features,
)
@@ -85,7 +88,7 @@ def __init__(
Block(
input_dim=hidden_dim,
num_heads=hidden_dim // head_size,
mlp_ratio=4,
mlp_ratio=mlp_ratio,
drop_path=0.0 * (i / (depth - 1)),
init_values=1,
)
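A minimal usage sketch (not part of the diff): `mlp_ratio` is now exposed on the constructor and forwarded to every attention Block instead of the previous hardcoded value, while the FourierEncoder keeps its automatic hidden size via mlp_dim=None. The keyword values below simply restate the defaults.

from graphnet.models.gnn.icemix import DeepIce

# Equivalent to the previous behaviour (mlp_ratio was hardcoded to 4);
# raise mlp_ratio to widen the feed-forward layers of every Block.
model = DeepIce(hidden_dim=384, mlp_ratio=4, seq_length=192, depth=12, head_size=32)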
3 changes: 3 additions & 0 deletions src/graphnet/models/transformer/__init__.py
@@ -0,0 +1,3 @@
"""Transformer-specific modules."""

from .iseecube import ISeeCube
102 changes: 102 additions & 0 deletions src/graphnet/models/transformer/iseecube.py
@@ -0,0 +1,102 @@
"""Implementation of ISeeCube Transformer architecture used in.
https://github.com/ChenLi2049/ISeeCube/
"""

import torch
import torch.nn as nn

from graphnet.models.components.embedding import FourierEncoder
from graphnet.models.gnn.gnn import GNN
from graphnet.models.utils import array_to_sequence

from torchscale.architecture.config import EncoderConfig
from torchscale.architecture.encoder import Encoder

from torch_geometric.data import Data
from torch import Tensor


class ISeeCube(GNN):
"""ISeeCube model."""

def __init__(
self,
hidden_dim: int = 384,
seq_length: int = 196,
num_layers: int = 16,
num_heads: int = 12,
mlp_dim: int = 1536,
rel_pos_buckets: int = 32,
max_rel_pos: int = 256,
num_register_tokens: int = 3,
scaled_emb: bool = False,
n_features: int = 6,
):
"""Construct `ISeeCube`.
Args:
hidden_dim: The latent feature dimension.
seq_length: The number of pulses in a neutrino event.
num_layers: The depth of the transformer.
num_heads: The number of attention heads.
mlp_dim: The MLP hidden dimension of the FourierEncoder and the
Transformer feed-forward layers.
rel_pos_buckets: Relative position buckets for relative position
bias.
max_rel_pos: Maximum relative position for relative position bias.
num_register_tokens: The number of register tokens.
scaled_emb: Whether to scale the sinusoidal positional embeddings.
n_features: The number of features in the input data.
"""
super().__init__(seq_length, hidden_dim)
self.fourier_ext = FourierEncoder(
seq_length=seq_length,
mlp_dim=mlp_dim,
output_dim=hidden_dim,
scaled=scaled_emb,
n_features=n_features,
)
self.pos_embedding = nn.Parameter(
torch.empty(1, seq_length, hidden_dim).normal_(std=0.02),
requires_grad=True,
)

self.class_token = nn.Parameter(
torch.empty(1, 1, hidden_dim),
requires_grad=True,
)
self.register_tokens = nn.Parameter(
torch.empty(1, num_register_tokens, hidden_dim),
requires_grad=True,
)

encoder_config = EncoderConfig(
encoder_attention_heads=num_heads,
encoder_embed_dim=hidden_dim,
encoder_ffn_embed_dim=mlp_dim,
encoder_layers=num_layers,
rel_pos_buckets=rel_pos_buckets,
max_rel_pos=max_rel_pos,
)
self.encoder = Encoder(encoder_config)

self.layer_norm = nn.LayerNorm(hidden_dim)

def forward(self, data: Data) -> Tensor:
"""Apply learnable forward pass."""
x, _, _ = array_to_sequence(data.x, data.batch, padding_value=0)
x = self.fourier_ext(x)
batch_size = x.shape[0]

x += self.pos_embedding

batch_class_token = self.class_token.expand(batch_size, -1, -1)
batch_register_tokens = self.register_tokens.expand(batch_size, -1, -1)
x = torch.cat([batch_class_token, batch_register_tokens, x], dim=1)

x = self.encoder(src_tokens=None, token_embeddings=x)
x = x["encoder_out"]

x = self.layer_norm(x)

return x[:, 0]
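Finally, a minimal instantiation sketch (not part of the commit), assuming graphnet is installed together with the torchscale pin added above; the keyword values restate the constructor defaults.

from graphnet.models.transformer import ISeeCube

model = ISeeCube(
    hidden_dim=384,
    seq_length=196,
    num_layers=16,
    num_heads=12,
    mlp_dim=1536,
)

# forward() prepends one class token plus `num_register_tokens` register tokens,
# runs the torchscale Encoder with relative position bias, and returns only the
# class-token row, i.e. a (batch_size, hidden_dim) summary per event.
n_params = sum(p.numel() for p in model.parameters())
print(f"ISeeCube parameters: {n_params:,}")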
