This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit: positional-embedding-hidden-grid
sahahner committed Dec 2, 2024
1 parent fd2bcf1 commit bbba4f6
Showing 4 changed files with 28 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/anemoi/models/layers/block.py
@@ -68,6 +68,7 @@ def __init__(
        num_heads: int,
        activation: str,
        window_size: int,
        grid_lat_coslon_sinlon: Tensor = None,
        dropout_p: float = 0.0,
    ):
        super().__init__()
@@ -80,6 +81,11 @@ def __init__(

        self.layer_norm1 = nn.LayerNorm(num_channels)

        self.grid_lat_coslon_sinlon = grid_lat_coslon_sinlon
        if self.grid_lat_coslon_sinlon is not None:
            self.grid_lat_coslon_sinlon = self.grid_lat_coslon_sinlon

@ssmmnn11 (Member) commented on Dec 4, 2024:

I guess here you can leave out the part

    if grid_lat_coslon_sinlon is not None:
        self.grid_lat_coslon_sinlon = grid_lat_coslon_sinlon

because you have already assigned it above.

I think it would be good to make it a buffer so that it is moved to the GPU together with the model. Currently you copy it from CPU to GPU in each forward pass:

    if self.grid_lat_coslon_sinlon is not None:
        pos_embedding = self.pos_embedder(self.grid_lat_coslon_sinlon.to(x.device))
        pos_embedding = pos_embedding.repeat(batch_size, 1)
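
A minimal sketch of the buffer approach suggested above, using PyTorch's register_buffer; the class skeleton and names below are illustrative only, not the actual anemoi-models code:

import torch
from torch import Tensor, nn


class BlockWithPositionalEmbedding(nn.Module):
    """Illustrative skeleton only."""

    def __init__(self, num_channels: int, grid_lat_coslon_sinlon: Tensor = None):
        super().__init__()
        # register_buffer ties the tensor to the module, so model.to(device)
        # moves it once instead of copying CPU -> GPU in every forward pass.
        if grid_lat_coslon_sinlon is not None:
            self.register_buffer("grid_lat_coslon_sinlon", grid_lat_coslon_sinlon, persistent=False)
            self.pos_embedder = nn.Linear(3, num_channels)
        else:
            self.grid_lat_coslon_sinlon = None

    def forward(self, x: Tensor, batch_size: int) -> Tensor:
        if self.grid_lat_coslon_sinlon is not None:
            # The buffer already lives on x.device, no .to(x.device) copy needed.
            pos_embedding = self.pos_embedder(self.grid_lat_coslon_sinlon)
            x = x + pos_embedding.repeat(batch_size, 1)
        return x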

@sahahner (Author, Member) replied on Dec 5, 2024:

Thank you for these comments. The issues should be resolved by the latest commit.

            self.pos_embedder = nn.Linear(3, num_channels)  # assuming that we have 3 position features, lat and cos / sin of lon

        self.attention = MultiHeadSelfAttention(
            num_heads=num_heads,
            embed_dim=num_channels,
@@ -99,6 +105,11 @@ def __init__(
    def forward(
        self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
    ) -> Tensor:
        if self.grid_lat_coslon_sinlon is not None:
            pos_embedding = self.pos_embedder(self.grid_lat_coslon_sinlon.to(x.device))
            pos_embedding = pos_embedding.repeat(batch_size, 1)
            x = x + pos_embedding

        # Need to be out of place for gradient propagation
        x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group)
        x = x + self.mlp(self.layer_norm2(x))
2 changes: 2 additions & 0 deletions src/anemoi/models/layers/chunk.py
@@ -74,6 +74,7 @@ def __init__(
        num_heads: int = 16,
        mlp_hidden_ratio: int = 4,
        activation: str = "GELU",
        grid_lat_coslon_sinlon: Tensor = None,
        dropout_p: float = 0.0,
    ) -> None:
        """Initialize TransformerProcessor.
@@ -102,6 +103,7 @@ def __init__(
            num_heads=num_heads,
            activation=activation,
            window_size=window_size,
            grid_lat_coslon_sinlon=grid_lat_coslon_sinlon,
            dropout_p=dropout_p,
        )

5 changes: 5 additions & 0 deletions src/anemoi/models/layers/processor.py
@@ -40,6 +40,7 @@ def __init__(
        num_chunks: int = 2,
        activation: str = "GELU",
        cpu_offload: bool = False,
        grid_lat_coslon_sinlon: Tensor = None,
        **kwargs,
    ) -> None:
        """Initialize BaseProcessor."""
@@ -49,6 +50,7 @@ def __init__(
        self.num_chunks = num_chunks
        self.num_channels = num_channels
        self.chunk_size = num_layers // num_chunks
        self.grid_lat_coslon_sinlon = grid_lat_coslon_sinlon

        assert (
            num_layers % num_chunks == 0
@@ -94,6 +96,7 @@ def __init__(
        num_chunks: int = 2,
        activation: str = "GELU",
        cpu_offload: bool = False,
        grid_lat_coslon_sinlon: Tensor = None,
        num_heads: int = 16,
        mlp_hidden_ratio: int = 4,
        dropout_p: float = 0.1,
@@ -125,6 +128,7 @@ def __init__(
            num_chunks=num_chunks,
            activation=activation,
            cpu_offload=cpu_offload,
            grid_lat_coslon_sinlon=grid_lat_coslon_sinlon,
            num_heads=num_heads,
            mlp_hidden_ratio=mlp_hidden_ratio,
        )
@@ -137,6 +141,7 @@ def __init__(
            num_layers=self.chunk_size,
            window_size=window_size,
            activation=activation,
            grid_lat_coslon_sinlon=grid_lat_coslon_sinlon,
            dropout_p=dropout_p,
        )

10 changes: 10 additions & 0 deletions src/anemoi/models/models/encoder_processor_decoder.py
@@ -76,11 +76,21 @@ def __init__(
            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
        )

        latlons_hidden = self.node_attributes.get_coordinates(self._graph_name_hidden)
        lat_coslon_sinlon_hidden = torch.cat(  # lat, cos(lon), sin(lon) for hidden grid points
            (
                latlons_hidden[:, 0].unsqueeze(-1),
                torch.cos(latlons_hidden[:, 1].unsqueeze(-1)),
                torch.sin(latlons_hidden[:, 1].unsqueeze(-1)),
            ),
            dim=-1,
        )

        # Processor hidden -> hidden
        self.processor = instantiate(
            model_config.model.processor,
            num_channels=self.num_channels,
            sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
            grid_lat_coslon_sinlon=lat_coslon_sinlon_hidden,
            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
        )
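
For reference, a self-contained sketch of the positional features this commit introduces and of how the processor block consumes them; the toy coordinates, shapes, and variable names are illustrative only:

import torch
from torch import nn

# Toy hidden-grid coordinates in radians: column 0 = latitude, column 1 = longitude.
latlons_hidden = torch.tensor([[0.0, 0.0], [0.5, 1.0], [-0.8, 2.5]])

# Three positional features per hidden grid point: lat, cos(lon), sin(lon).
lat_coslon_sinlon_hidden = torch.cat(
    (
        latlons_hidden[:, 0].unsqueeze(-1),
        torch.cos(latlons_hidden[:, 1].unsqueeze(-1)),
        torch.sin(latlons_hidden[:, 1].unsqueeze(-1)),
    ),
    dim=-1,
)
print(lat_coslon_sinlon_hidden.shape)  # (num_hidden_nodes, 3)

# The processor block projects the 3 features to num_channels and adds the
# result to the flattened hidden-grid tensor at the start of forward().
num_channels, batch_size = 8, 2
pos_embedder = nn.Linear(3, num_channels)
pos_embedding = pos_embedder(lat_coslon_sinlon_hidden).repeat(batch_size, 1)
x = torch.zeros(batch_size * latlons_hidden.shape[0], num_channels)
x = x + pos_embedding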

0 comments on commit bbba4f6
