Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
rollback register_latlon()
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Jul 5, 2024
1 parent ad4e0a9 commit 127e535
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/anemoi/models/models/encoder_processor_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def __init__(
self._create_trainable_attributes()

# Register lat/lon of nodes
for nodes_name in self._graph_data.node_types:
self._register_latlon(nodes_name)
self._register_latlon("data", self._graph_name_data)
self._register_latlon("hidden", self._graph_name_hidden)

self.num_channels = config.model.num_channels

Expand Down Expand Up @@ -126,17 +126,17 @@ def _define_tensor_sizes(self, config: DotDict) -> None:
self.trainable_data_size = config.model.trainable_parameters.data
self.trainable_hidden_size = config.model.trainable_parameters.hidden

def _register_latlon(self, nodes_name: str) -> None:
def _register_latlon(self, name: str, nodes: str) -> None:
"""Register lat/lon buffers.
Parameters
----------
nodes_name : str
Name of nodes to map
"""
coords = self._graph_data[nodes_name].x
coords = self._graph_data[nodes].x
sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
self.register_buffer(f"latlons_{nodes_name}", sin_cos_coords, persistent=True)
self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)

def _create_trainable_attributes(self) -> None:
"""Create all trainable attributes."""
Expand Down

0 comments on commit 127e535

Please sign in to comment.