This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Provide positional embedding on the hidden grid for transformer #96

Draft · wants to merge 8 commits into base: develop
12 changes: 12 additions & 0 deletions src/anemoi/models/layers/block.py
@@ -68,6 +68,7 @@ def __init__(
num_heads: int,
activation: str,
window_size: int,
positional_encoding_hidden: Optional[Tensor] = None,
dropout_p: float = 0.0,
):
super().__init__()
@@ -80,6 +81,12 @@ def __init__(

self.layer_norm1 = nn.LayerNorm(num_channels)

self.register_buffer("positional_encoding_hidden", positional_encoding_hidden)
if self.positional_encoding_hidden is not None:
self.pos_embedder = nn.Linear(
self.positional_encoding_hidden.shape[-1], num_channels
) # project the position features (e.g. lat, cos(lon), sin(lon), or cos/sin of both) to the embedding dimension

self.attention = MultiHeadSelfAttention(
num_heads=num_heads,
embed_dim=num_channels,
@@ -99,6 +106,11 @@ def __init__(
def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
) -> Tensor:
if self.positional_encoding_hidden is not None:
pos_embedding = self.pos_embedder(self.positional_encoding_hidden)
pos_embedding = pos_embedding.repeat(batch_size, 1)
x = x + pos_embedding

# Needs to be out-of-place for gradient propagation
x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group)
x = x + self.mlp(self.layer_norm2(x))
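
For readers skimming the diff, here is a minimal standalone sketch (hypothetical class name, reduced signature, not the full `TransformerProcessorBlock`) of the pattern this hunk introduces: the hidden-grid position features are registered as a non-trainable buffer, projected to `num_channels` with a learned `nn.Linear`, and added to the flattened hidden states before attention.

```python
from typing import Optional

import torch
from torch import Tensor, nn


class PosEmbeddedBlock(nn.Module):
    """Sketch of the positional-embedding pattern used in the processor block (illustrative only)."""

    def __init__(self, num_channels: int, positional_encoding_hidden: Optional[Tensor] = None) -> None:
        super().__init__()
        # Buffer: the position features are not trained, but follow the module's device/dtype and state_dict.
        self.register_buffer("positional_encoding_hidden", positional_encoding_hidden)
        if self.positional_encoding_hidden is not None:
            # Learned projection from the position features (e.g. 3 or 4 of them) to the embedding width.
            self.pos_embedder = nn.Linear(self.positional_encoding_hidden.shape[-1], num_channels)

    def forward(self, x: Tensor, batch_size: int) -> Tensor:
        if self.positional_encoding_hidden is not None:
            pos_embedding = self.pos_embedder(self.positional_encoding_hidden)  # (grid_size, num_channels)
            # x is flattened over (batch, grid), so tile the per-node embedding once per batch element.
            pos_embedding = pos_embedding.repeat(batch_size, 1)
            x = x + pos_embedding
        return x  # attention and MLP residual branches would follow here


# Usage: 4 position features per hidden node, 8 channels, batch of 2.
block = PosEmbeddedBlock(num_channels=8, positional_encoding_hidden=torch.rand(10, 4))
out = block(torch.zeros(2 * 10, 8), batch_size=2)
print(out.shape)  # torch.Size([20, 8])
```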
2 changes: 2 additions & 0 deletions src/anemoi/models/layers/chunk.py
@@ -74,6 +74,7 @@ def __init__(
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
activation: str = "GELU",
positional_encoding_hidden: Optional[Tensor] = None,
dropout_p: float = 0.0,
) -> None:
"""Initialize TransformerProcessor.
@@ -102,6 +103,7 @@ def __init__(
num_heads=num_heads,
activation=activation,
window_size=window_size,
positional_encoding_hidden=positional_encoding_hidden,
dropout_p=dropout_p,
)

59 changes: 59 additions & 0 deletions src/anemoi/models/layers/positionalencoding.py
@@ -0,0 +1,59 @@
from abc import ABC
from abc import abstractmethod

import torch
from torch import Tensor


class BasePositionalEncoding(ABC):
"""Configurable method calcuating positional encodings for latlons of a grid.

To enable the positional encoding add the following to the model-config file and
chose the corresponding positional-encoding-class:
```
positional_encoding:
_target_: anemoi.models.layers.positionalencoding.CosSinLatCosSinLon
_convert_: all
```
If the entry positional_encoding does not exist or is None, no positional encoding is used.

"""

def __init__(self) -> None:
"""Initialise Function for calculating the positional encodings."""

@abstractmethod
def positional_encoding(self, latlons_hidden: Tensor) -> Tensor: ...


class LatCosSinLon(BasePositionalEncoding):
"""Lat, cos(lon), sin(lon) for grid points."""

def positional_encoding(self, latlons_hidden: Tensor) -> Tensor:
"""Output lat, cos(lon), sin(lon) for grid points."""
lat_coslon_sinlon_hidden = torch.cat(
(
latlons_hidden[:, 0].unsqueeze(-1),
torch.cos(latlons_hidden[:, 1].unsqueeze(-1)),
torch.sin(latlons_hidden[:, 1].unsqueeze(-1)),
),
dim=-1,
)
return lat_coslon_sinlon_hidden


class CosSinLatCosSinLon(BasePositionalEncoding):
"""Cos(lat), sin(lat), cos(lon), sin(lon) for grid points."""

def positional_encoding(self, latlons_hidden: Tensor) -> Tensor:
"""Output cos(lat), sin(lat), cos(lon), sin(lon) for grid points."""
coslat_sinlat_coslon_sinlon_hidden = torch.cat(
(
torch.cos(latlons_hidden[:, 0].unsqueeze(-1)),
torch.sin(latlons_hidden[:, 0].unsqueeze(-1)),
torch.cos(latlons_hidden[:, 1].unsqueeze(-1)),
torch.sin(latlons_hidden[:, 1].unsqueeze(-1)),
),
dim=-1,
)
return coslat_sinlat_coslon_sinlon_hidden
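
As a quick usage sketch (not part of the PR; it assumes the coordinates are `(lat, lon)` pairs in radians and that the classes above are importable from the new module), the two encodings produce 3 and 4 features per grid point respectively:

```python
import torch

from anemoi.models.layers.positionalencoding import CosSinLatCosSinLon, LatCosSinLon

# Hypothetical hidden-grid coordinates in radians, shape (num_nodes, 2): (lat, lon) per node.
latlons_hidden = torch.tensor([[0.0, 0.0], [0.7854, 1.5708]])

enc3 = LatCosSinLon().positional_encoding(latlons_hidden)
print(enc3.shape)  # torch.Size([2, 3]) -> lat, cos(lon), sin(lon)

enc4 = CosSinLatCosSinLon().positional_encoding(latlons_hidden)
print(enc4.shape)  # torch.Size([2, 4]) -> cos(lat), sin(lat), cos(lon), sin(lon)
```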
5 changes: 5 additions & 0 deletions src/anemoi/models/layers/processor.py
@@ -40,6 +40,7 @@ def __init__(
num_chunks: int = 2,
activation: str = "GELU",
cpu_offload: bool = False,
positional_encoding_hidden: Optional[Tensor] = None,
**kwargs,
) -> None:
"""Initialize BaseProcessor."""
@@ -49,6 +50,7 @@
self.num_chunks = num_chunks
self.num_channels = num_channels
self.chunk_size = num_layers // num_chunks
self.positional_encoding_hidden = positional_encoding_hidden

assert (
num_layers % num_chunks == 0
@@ -94,6 +96,7 @@ def __init__(
num_chunks: int = 2,
activation: str = "GELU",
cpu_offload: bool = False,
positional_encoding_hidden: Optional[Tensor] = None,
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
dropout_p: float = 0.1,
@@ -125,6 +128,7 @@ def __init__(
num_chunks=num_chunks,
activation=activation,
cpu_offload=cpu_offload,
positional_encoding_hidden=positional_encoding_hidden,
num_heads=num_heads,
mlp_hidden_ratio=mlp_hidden_ratio,
)
@@ -137,6 +141,7 @@ def __init__(
num_layers=self.chunk_size,
window_size=window_size,
activation=activation,
positional_encoding_hidden=positional_encoding_hidden,
dropout_p=dropout_p,
)

11 changes: 11 additions & 0 deletions src/anemoi/models/models/encoder_processor_decoder.py
@@ -76,11 +76,22 @@ def __init__(
dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
)

positional_encoding_hidden = None
if model_config.model.get("positional_encoding") is not None:
LOGGER.info(
"Using positional encoding. Target function: %s", model_config.model.positional_encoding._target_
)
self.positional_encoding = instantiate(model_config.model.positional_encoding)
positional_encoding_hidden = self.positional_encoding.positional_encoding(
self.node_attributes.get_coordinates(self._graph_name_hidden)
)

# Processor hidden -> hidden
self.processor = instantiate(
model_config.model.processor,
num_channels=self.num_channels,
sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
positional_encoding_hidden=positional_encoding_hidden,
src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
)
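
To make the wiring concrete, here is a hedged sketch of the same flow without Hydra (the PR itself resolves `model_config.model.positional_encoding` via `instantiate`); the coordinate tensor and node count are made up for illustration:

```python
import torch

from anemoi.models.layers.positionalencoding import CosSinLatCosSinLon

# Hypothetical hidden-grid coordinates, shape (num_hidden_nodes, 2): (lat, lon) per node.
latlons_hidden = torch.rand(1000, 2)

# What the config entry `model.positional_encoding` resolves to, instantiated by hand here.
positional_encoding = CosSinLatCosSinLon()
positional_encoding_hidden = positional_encoding.positional_encoding(latlons_hidden)

# This (num_hidden_nodes, 4) tensor is handed down processor -> chunk -> block,
# where each block projects it to num_channels and adds it to the hidden states.
print(positional_encoding_hidden.shape)  # torch.Size([1000, 4])
```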