This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 5, 2024
1 parent 8bc7d79 commit 5505e97
Showing 9 changed files with 31 additions and 38 deletions.
6 changes: 3 additions & 3 deletions src/anemoi/models/layers/attention.py
@@ -25,11 +25,11 @@
 else:
     _FLASH_ATTENTION_AVAILABLE = True
 
-from anemoi.utils.config import DotDict
-
 from anemoi.models.distributed.transformer import shard_heads
 from anemoi.models.distributed.transformer import shard_sequence
 
+from anemoi.utils.config import DotDict
+
 LOGGER = logging.getLogger(__name__)
 
 
@@ -59,7 +59,7 @@ def __init__(
         self.dropout_p = dropout_p
         self.is_causal = is_causal
 
-        linear=layer_kernels["Linear"]
+        linear = layer_kernels["Linear"]
         self.lin_qkv = linear(embed_dim, 3 * embed_dim, bias=bias)
         self.attention = attn_func
 
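A note on the line this second hunk reformats: `layer_kernels` is a mapping from layer names to constructors, so `linear = layer_kernels["Linear"]` picks the configured Linear implementation instead of hard-coding `torch.nn.Linear`, and `self.lin_qkv = linear(embed_dim, 3 * embed_dim, bias=bias)` then builds the fused QKV projection with it. Below is a minimal runnable sketch of that pattern, assuming a plain dict of `functools.partial` constructors; the stub class is illustrative, not the actual anemoi attention layer.

from functools import partial

import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention

# Hypothetical kernel registry: each entry is a callable that builds a layer.
layer_kernels = {
    "Linear": partial(nn.Linear),
    "LayerNorm": partial(nn.LayerNorm),
}


class SelfAttentionStub(nn.Module):
    """Illustrative stub: the looked-up kernel replaces a hard-coded nn.Linear."""

    def __init__(self, embed_dim: int, layer_kernels: dict, bias: bool = False) -> None:
        super().__init__()
        linear = layer_kernels["Linear"]  # configured Linear implementation
        self.lin_qkv = linear(embed_dim, 3 * embed_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.lin_qkv(x).chunk(3, dim=-1)
        return scaled_dot_product_attention(q, k, v)


x = torch.randn(2, 16, 32)  # (batch, sequence, embed_dim)
print(SelfAttentionStub(32, layer_kernels)(x).shape)  # torch.Size([2, 16, 32])

Swapping in a different kernel (for example a fused or low-precision Linear) then becomes a config change rather than a code change.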
13 changes: 5 additions & 8 deletions src/anemoi/models/layers/block.py
@@ -16,6 +16,7 @@
 
 import einops
 import torch
+from anemoi.utils.config import DotDict
 from torch import Tensor
 from torch import nn
 from torch.distributed.distributed_c10d import ProcessGroup
@@ -32,7 +33,6 @@
 from anemoi.models.layers.conv import GraphConv
 from anemoi.models.layers.conv import GraphTransformerConv
 from anemoi.models.layers.mlp import MLP
-from anemoi.utils.config import DotDict
 
 LOGGER = logging.getLogger(__name__)
 
@@ -100,9 +100,7 @@ def __init__(
         )
 
     def forward(
-        self, x: Tensor, shapes: list, batch_size: int,
-        model_comm_group: Optional[ProcessGroup] = None,
-        **kwargs
+        self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None, **kwargs
     ) -> Tensor:
         # Need to be out of place for gradient propagation
         x = x + self.attention(self.layer_norm1(x), shapes, batch_size, model_comm_group=model_comm_group)
@@ -348,8 +346,8 @@ def __init__(
 
         self.num_chunks = num_chunks
 
-        linear=layer_kernels['Linear']
-        layerNorm=layer_kernels['LayerNorm']
+        linear = layer_kernels["Linear"]
+        layerNorm = layer_kernels["LayerNorm"]
         self.lin_key = linear(in_channels, num_heads * self.out_channels_conv)
         self.lin_query = linear(in_channels, num_heads * self.out_channels_conv)
         self.lin_value = linear(in_channels, num_heads * self.out_channels_conv)
@@ -627,8 +625,8 @@ def __init__(
             bias=bias,
             activation=activation,
             num_chunks=num_chunks,
-            update_src_nodes=update_src_nodes
-            **kwargs,
+            update_src_nodes=update_src_nodes**kwargs,
         )
 
     def forward(
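The last hunk above is worth a closer look. Because the original call was missing the trailing comma after `update_src_nodes=update_src_nodes`, Python was already parsing the two physical lines as a single keyword argument whose value is the power expression `update_src_nodes ** kwargs`; the formatter only joined what the grammar had already joined. A small sketch (not part of the commit) showing the parse difference with the standard-library `ast` module, using a placeholder call `f(...)`:

import ast

# Hypothetical snippets mirroring the two spellings in the hunk above.
broken = "f(update_src_nodes=update_src_nodes\n**kwargs,)"    # no trailing comma
intended = "f(update_src_nodes=update_src_nodes,\n**kwargs,)"  # comma, then dict unpacking

broken_call = ast.parse(broken, mode="eval").body
intended_call = ast.parse(intended, mode="eval").body

# One keyword whose value is update_src_nodes ** kwargs (a BinOp with the Pow operator):
print(len(broken_call.keywords), type(broken_call.keywords[0].value).__name__)  # 1 BinOp
# Two keywords: the named argument plus the **kwargs unpacking (arg=None):
print(len(intended_call.keywords), [k.arg for k in intended_call.keywords])  # 2 ['update_src_nodes', None]

So the pre-existing code never forwarded `**kwargs` here at all; it passed one malformed argument that would raise a `TypeError` at call time (for example, a bool raised to the power of a dict). Restoring the comma, rather than keeping the joined form, would be the actual fix.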
13 changes: 7 additions & 6 deletions src/anemoi/models/layers/chunk.py
@@ -13,6 +13,7 @@
 from abc import abstractmethod
 from typing import Optional
 
+from anemoi.utils.config import DotDict
 from torch import Tensor
 from torch import nn
 from torch.distributed.distributed_c10d import ProcessGroup
@@ -24,7 +25,6 @@
 from anemoi.models.layers.block import GraphTransformerProcessorBlock
 from anemoi.models.layers.block import TransformerProcessorBlock
 from anemoi.models.layers.mlp import MLP
-from anemoi.utils.config import DotDict
 
 LOGGER = logging.getLogger(__name__)
 
@@ -111,18 +111,19 @@ def __init__(
             activation=activation,
             window_size=window_size,
             dropout_p=dropout_p,
-            layer_kernels=layer_kernels
+            layer_kernels=layer_kernels,
         )
 
     def forward(
-        self, x: Tensor, shapes: list, batch_size: int,
+        self,
+        x: Tensor,
+        shapes: list,
+        batch_size: int,
         model_comm_group: Optional[ProcessGroup] = None,
         **kwargs,
     ) -> Tensor:
         for i in range(self.num_layers):
-            x = self.blocks[i](x, shapes, batch_size,
-                model_comm_group=model_comm_group,
-                **kwargs)
+            x = self.blocks[i](x, shapes, batch_size, model_comm_group=model_comm_group, **kwargs)
 
         return (x,)  # return tuple for consistency with other processors
 
8 changes: 4 additions & 4 deletions src/anemoi/models/layers/mapper.py
@@ -14,6 +14,7 @@
 
 import numpy as np
 import torch
+from anemoi.utils.config import DotDict
 from torch import Tensor
 from torch import nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper
@@ -31,7 +32,6 @@
 from anemoi.models.layers.block import GraphTransformerMapperBlock
 from anemoi.models.layers.graph import TrainableTensor
 from anemoi.models.layers.mlp import MLP
-from anemoi.utils.config import DotDict
 
 LOGGER = logging.getLogger(__name__)
 
@@ -229,8 +229,8 @@ def __init__(
             activation=activation,
             layer_kernels=layer_kernels,
         )
-        #Linear = layer_kernels.get("Linear", torch.nn.Linear)
 
+        # Linear = layer_kernels.get("Linear", torch.nn.Linear)
         Linear = layer_kernels["Linear"]
 
         self._register_edges(sub_graph, sub_graph_edge_attributes, src_grid_size, dst_grid_size, trainable_size)
@@ -245,7 +245,7 @@ def __init__(
             edge_dim=self.edge_dim,
             activation=activation,
             num_chunks=num_chunks,
-            layer_kernels=layer_kernels
+            layer_kernels=layer_kernels,
         )
 
         self.offload_layers(cpu_offload)
5 changes: 2 additions & 3 deletions src/anemoi/models/layers/mlp.py
@@ -11,11 +11,10 @@
 import logging
 
 import torch
+from anemoi.utils.config import DotDict
 from torch import nn
 
 from anemoi.models.layers.normalization import AutocastLayerNorm
 from anemoi.models.layers.utils import CheckpointWrapper
-from anemoi.utils.config import DotDict
 
 LOGGER = logging.getLogger(__name__)
 
@@ -73,7 +72,7 @@ def __init__(
 
         Linear = layer_kernels["Linear"]
         LayerNorm = layer_kernels["LayerNorm"]
-        
+
         try:
             act_func = getattr(nn, activation)
         except AttributeError as ae:
3 changes: 0 additions & 3 deletions src/anemoi/models/layers/normalization.py
@@ -9,10 +9,7 @@
 
 from __future__ import annotations
 
-from abc import ABC
-from abc import abstractmethod
-
 import torch
 from torch import nn
 
 
4 changes: 2 additions & 2 deletions src/anemoi/models/layers/processor.py
@@ -11,6 +11,7 @@
 from abc import ABC
 from typing import Optional
 
+from anemoi.utils.config import DotDict
 from torch import Tensor
 from torch import nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper
@@ -27,7 +28,6 @@
 from anemoi.models.layers.chunk import TransformerProcessorChunk
 from anemoi.models.layers.graph import TrainableTensor
 from anemoi.models.layers.mapper import GraphEdgeMixin
-from anemoi.utils.config import DotDict
 
 
 class BaseProcessor(nn.Module, ABC):
@@ -132,7 +132,7 @@ def __init__(
             cpu_offload=cpu_offload,
             num_heads=num_heads,
             mlp_hidden_ratio=mlp_hidden_ratio,
-            #layer_kernels=layer_kernels,
+            # layer_kernels=layer_kernels,
         )
 
         self.build_layers(
1 change: 0 additions & 1 deletion src/anemoi/models/layers/utils.py
@@ -8,7 +8,6 @@
 # nor does it submit to any jurisdiction.
 
 
-from torch import Tensor
 from torch import nn
 from torch.utils.checkpoint import checkpoint
 
16 changes: 8 additions & 8 deletions src/anemoi/models/models/encoder_processor_decoder.py
@@ -14,13 +14,13 @@
 import einops
 import torch
 from anemoi.utils.config import DotDict
+from hydra.errors import InstantiationException
 from hydra.utils import instantiate
 from torch import Tensor
 from torch import nn
 from torch.distributed.distributed_c10d import ProcessGroup
 from torch.utils.checkpoint import checkpoint
 from torch_geometric.data import HeteroData
-from hydra.errors import InstantiationException
 
 from anemoi.models.distributed.shapes import get_shape_shards
 from anemoi.models.layers.graph import NamedNodesAttributes
@@ -65,7 +65,7 @@ def __init__(
         self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data)
 
         input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data]
-        
+
         # read config.model.layer_kernels to get the implementation for certain layers
         self._load_layer_kernels(model_config)
 
@@ -242,13 +242,13 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
     def _load_layer_kernels(self, config: DotDict) -> None:
 
         # If self.layer_kernels entry is missing from the config, use torch.nn by default
-        default_kernels=DotDict()
+        default_kernels = DotDict()
         default_kernels["Linear"] = DotDict({"_target_": "torch.nn.Linear", "_partial_": True})
         default_kernels["LayerNorm"] = DotDict({"_target_": "torch.nn.LayerNorm", "_partial_": True})
-        #self.layer_kernels = config.get("model.layer_kernels", default_kernels) #Always uses default kernels...
-        self.layer_kernels= config.model.layer_kernels
-
+        # self.layer_kernels = config.get("model.layer_kernels", default_kernels) #Always uses default kernels...
+        self.layer_kernels = config.model.layer_kernels
+
         # Loop through all kernels in the layer_kernels config entry and try import them
         for kernel in self.layer_kernels:
             kernel_entry = self.layer_kernels[kernel]
@@ -260,4 +260,4 @@ def _load_layer_kernels(self, config: DotDict) -> None:
                 )
                 raise InstantiationException
             else:
-                LOGGER.info(f"{kernel} kernel: {kernel_entry}")
\ No newline at end of file
+                LOGGER.info(f"{kernel} kernel: {kernel_entry}")
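For context on `_load_layer_kernels`: each entry in `config.model.layer_kernels` follows Hydra's `_target_`/`_partial_` convention, so instantiating it yields a partially-applied constructor that the layers above then index with `layer_kernels["Linear"]`. A minimal sketch of that mechanism, assuming `hydra-core`, `omegaconf`, and `torch` are installed; the config values mirror the defaults in the diff, but the surrounding loop is illustrative rather than the anemoi implementation:

from hydra.errors import InstantiationException
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Default kernel config mirroring the diff: partial constructors for torch.nn layers.
layer_kernels_cfg = OmegaConf.create(
    {
        "Linear": {"_target_": "torch.nn.Linear", "_partial_": True},
        "LayerNorm": {"_target_": "torch.nn.LayerNorm", "_partial_": True},
    }
)

kernels = {}
for name, entry in layer_kernels_cfg.items():
    try:
        kernels[name] = instantiate(entry)  # functools.partial(torch.nn.Linear), etc.
    except InstantiationException:
        print(f"Could not instantiate kernel {name}: {entry}")  # the real code logs and re-raises
        raise

linear = kernels["Linear"]               # still a partial; needs in/out features
layer = linear(32, 96, bias=False)       # -> torch.nn.Linear(32, 96, bias=False)
print(type(layer).__name__, layer.weight.shape)  # Linear torch.Size([96, 32])

The commented-out `config.get("model.layer_kernels", default_kernels)` line, together with its `#Always uses default kernels...` remark, suggests why the fallback was dropped: the dotted-key lookup apparently never resolved against the nested config, so it always fell back to the defaults.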
