This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit
Merge pull request #84 from ecmwf/feature/reduce-decoder-mem-usage
mchantry authored Dec 13, 2024
2 parents f58124e + 0fb033a commit 90ac4df
Showing 2 changed files with 3 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -38,6 +38,7 @@ Keep it human-readable, your future self will thank you!
 - Update copyright notice
 - Fix `__version__` import in init
 - Fix missing copyrights [#71](https://github.com/ecmwf/anemoi-models/pull/71)
+- Reduced memory usage when using chunking in the mapper [#84](https://github.com/ecmwf/anemoi-models/pull/84)

### Removed

6 changes: 2 additions & 4 deletions src/anemoi/models/layers/block.py
@@ -512,18 +512,16 @@ def forward(
             edge_attr_list, edge_index_list = sort_edges_1hop_chunks(
                 num_nodes=size, edge_attr=edges, edge_index=edge_index, num_chunks=num_chunks
             )
+            out = torch.zeros((x[1].shape[0], self.num_heads, self.out_channels_conv), device=x[1].device)
             for i in range(num_chunks):
-                out1 = self.conv(
+                out += self.conv(
                     query=query,
                     key=key,
                     value=value,
                     edge_attr=edge_attr_list[i],
                     edge_index=edge_index_list[i],
                     size=size,
                 )
-                if i == 0:
-                    out = torch.zeros_like(out1, device=out1.device)
-                out = out + out1
         else:
             out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)

