Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a decoding mask option to only include subset of grid nodes in m2g #34

Open
wants to merge 50 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
8a11527
Start looking into where xy has to be changed
joeloskarsson Oct 2, 2024
2c7e4eb
Change shapes in docstrings
joeloskarsson Oct 3, 2024
6f03908
Make flat mesh graphs work with new coordinate layout
joeloskarsson Oct 3, 2024
d2a081a
Merge PR #26 into branch
joeloskarsson Oct 13, 2024
9dbd2a5
Fix coordinate handling in multirange graph creation
joeloskarsson Oct 13, 2024
eada6ea
Rename grid_refinement_factor to mesh_node_distance
joeloskarsson Oct 13, 2024
01180a1
Fix existing tests to work with new coordinate format
joeloskarsson Oct 13, 2024
d103ad6
Add test for irregularlygridded coordinates
joeloskarsson Oct 13, 2024
7ba387e
Remove unneeded eps in mesh level calculation
joeloskarsson Oct 14, 2024
9069c84
Change documentation to use new format and arguments for coordinates
joeloskarsson Oct 14, 2024
d592c2b
Fix bug in coordinate order for flat graphs
joeloskarsson Oct 14, 2024
fd695b1
Start working on allowing latlon coordinates
joeloskarsson Oct 14, 2024
2c781ee
Introduce coords and projection
joeloskarsson Oct 14, 2024
5cf3bbf
Merge branch 'main' into general_coordinates
joeloskarsson Oct 14, 2024
0f17d2c
Fix linting
joeloskarsson Oct 14, 2024
f145025
Fix tests with coords keyword argument
joeloskarsson Oct 14, 2024
cc4cc5e
Implement lat-lon transformation through projection
joeloskarsson Oct 16, 2024
2627e37
Add documentation page about graphs constructed using lat-lons
joeloskarsson Oct 16, 2024
c764fd7
Adjust coords keyword arg in docs
joeloskarsson Oct 16, 2024
f6ae35b
Add test for lat-lon coordinates
joeloskarsson Oct 16, 2024
22caf65
Fix linting of docs
joeloskarsson Oct 16, 2024
7ec34ff
Merge main into branch
joeloskarsson Oct 17, 2024
f564c70
Add decode_mask for only including subset of grid nodes in m2g
joeloskarsson Oct 21, 2024
2333d53
Add test for decode filtering
joeloskarsson Oct 21, 2024
3ee25ae
Fix typos and clarifications as suggested from code review
joeloskarsson Oct 22, 2024
70eef3e
Change euclidean coordinates to Cartesian coordinates
joeloskarsson Oct 23, 2024
3c4866b
Merge branch 'main' into general_coordinates
joeloskarsson Oct 23, 2024
5f33bc5
Merge branch 'general_coordinates' into decoding_mask
joeloskarsson Oct 23, 2024
63df482
Merge branch 'main' into decoding_mask
joeloskarsson Nov 11, 2024
350e0c0
Sort nodes and subgraphs for saving
joeloskarsson Nov 11, 2024
806b78f
Fix linting
joeloskarsson Nov 11, 2024
4fb69dc
Merge branch 'main' into general_coordinates
joeloskarsson Nov 18, 2024
8fcf182
Apply suggested documentation and code readability updates
joeloskarsson Nov 19, 2024
746966f
Update src/weather_model_graphs/create/mesh/kinds/hierarchical.py
joeloskarsson Nov 26, 2024
47b5dcc
Clarify comments and variable names around mesh level computation
joeloskarsson Nov 26, 2024
05a0cc6
Add check for number of nodes in test with irregular coords
joeloskarsson Nov 26, 2024
339feff
Update docs line on square meshes
joeloskarsson Nov 26, 2024
6b281bf
Reference lat-lon notebook in coordinate section
joeloskarsson Nov 26, 2024
a525818
Change projection spec to use pyproj crs:s
joeloskarsson Nov 26, 2024
8733359
Adjust test to crs arguments
joeloskarsson Nov 26, 2024
c95f023
Fix linting
joeloskarsson Nov 26, 2024
bcaf1e1
Update docs to crs change
joeloskarsson Nov 26, 2024
8e3c1cb
Add cartopy dependency to visualization group
joeloskarsson Nov 26, 2024
541054e
Merge branch 'general_coordinates' into decoding_mask
joeloskarsson Nov 26, 2024
169ea48
Sort nodes by id before pyg conversion
joeloskarsson Nov 27, 2024
a6b6137
Introduce option to return graph components directly, used through kw…
joeloskarsson Nov 27, 2024
7576b43
Merge branch 'main' into decoding_mask
joeloskarsson Nov 29, 2024
ae90eb8
Add explanation of **kwargs to archetype docstrings
joeloskarsson Nov 29, 2024
ee8601e
add test ensuring unchanged grid-indecies /w decode mask
leifdenby Dec 6, 2024
e178a0b
add test and example notebook
leifdenby Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
473 changes: 473 additions & 0 deletions docs/decoding_mask.ipynb

Large diffs are not rendered by default.

45 changes: 14 additions & 31 deletions src/weather_model_graphs/create/archetype.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be good to keep the explicit arguments here. This would mean that this file is mostly unchanged, although there was a missing docstring for max_num_levels, in create_oskarsson_hierarchical_graph() that you have added in this PR

Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from .base import create_all_graph_components


def create_keisler_graph(
coords,
mesh_node_distance=3,
coords_crs=None,
graph_crs=None,
):
def create_keisler_graph(coords, mesh_node_distance=3, **kwargs):
"""
Create a flat LAM graph from Oskarsson et al (2023, https://arxiv.org/abs/2309.17370)
This graph setup is inspired by the global graph used by Keisler (2022, https://arxiv.org/abs/2202.07575).
Expand All @@ -28,11 +23,8 @@ def create_keisler_graph(
mesh_node_distance: float
Distance (in x- and y-direction) between created mesh nodes,
in coordinate system of coords
coords_crs: pyproj.crs.CRS or None
CRS of the given coordinates
graph_crs:
CRS to build graph in. If given, coords will be transformed from
coords_crs to graph_crs before graph construction
**kwargs:
Additional keyword arguments passed on to create_all_graph_components.

Returns
-------
Expand All @@ -51,8 +43,7 @@ def create_keisler_graph(
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
coords_crs=coords_crs,
graph_crs=graph_crs,
**kwargs,
)


Expand All @@ -61,8 +52,7 @@ def create_graphcast_graph(
mesh_node_distance=3,
level_refinement_factor=3,
max_num_levels=None,
coords_crs=None,
graph_crs=None,
**kwargs,
):
"""
Create a multiscale LAM graph from Oskarsson et al (2023, https://arxiv.org/abs/2309.17370)
Expand All @@ -89,11 +79,8 @@ def create_graphcast_graph(
NOTE: Must be an odd integer >1 to create proper multiscale graph
max_num_levels: int
The number of levels of longer-range connections in the mesh graph.
coords_crs: pyproj.crs.CRS or None
CRS of the given coordinates
graph_crs:
CRS to build graph in. If given, coords will be transformed from
coords_crs to graph_crs before graph construction
**kwargs:
Additional keyword arguments passed on to create_all_graph_components.

Returns
-------
Expand All @@ -116,8 +103,7 @@ def create_graphcast_graph(
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
coords_crs=coords_crs,
graph_crs=graph_crs,
**kwargs,
)


Expand All @@ -126,8 +112,7 @@ def create_oskarsson_hierarchical_graph(
mesh_node_distance=3,
level_refinement_factor=3,
max_num_levels=None,
coords_crs=None,
graph_crs=None,
**kwargs,
):
"""
Create a LAM graph following Oskarsson et al (2023, https://arxiv.org/abs/2309.17370)
Expand Down Expand Up @@ -157,11 +142,10 @@ def create_oskarsson_hierarchical_graph(
in coordinate system of coords
level_refinement_factor: float
Refinement factor between grid points and bottom level of mesh hierarchy
coords_crs: pyproj.crs.CRS or None
CRS of the given coordinates
graph_crs:
CRS to build graph in. If given, coords will be transformed from
coords_crs to graph_crs before graph construction
max_num_levels: int
The number of levels of longer-range connections in the mesh graph.
**kwargs:
Additional keyword arguments passed on to create_all_graph_components.

Returns
-------
Expand All @@ -184,6 +168,5 @@ def create_oskarsson_hierarchical_graph(
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
coords_crs=coords_crs,
graph_crs=graph_crs,
**kwargs,
)
29 changes: 28 additions & 1 deletion src/weather_model_graphs/create/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"""


from typing import Iterable

import networkx
import networkx as nx
import numpy as np
Expand Down Expand Up @@ -39,6 +41,8 @@ def create_all_graph_components(
g2m_connectivity_kwargs={},
coords_crs: pyproj.crs.CRS | None = None,
graph_crs: pyproj.crs.CRS | None = None,
decode_mask: Iterable | None = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would move this decode_mask argument into the m2g_connectivity arguments I think and then masking could in fact be done within connect_nodes_across_graphs() since the masking effects the way nodes are connected between graphs. What do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could for example introduce arguments called source_nodes_mask and target_nodes_mask to connect_nodes_across_graphs (https://github.com/mllam/weather-model-graphs/blob/main/src/weather_model_graphs/create/base.py#L181)

return_components: bool = False,
):
"""
Create all graph components used in creating the message-passing graph,
Expand Down Expand Up @@ -82,6 +86,11 @@ def create_all_graph_components(
will be transformed from their original Coordinate Reference System (`coords_crs`)
to the CRS where the graph creation should take place (`graph_crs`).
If any one of them is None the graph creation is carried out using the original coords.

`decode_mask` should be an Iterable of booleans, masking which grid positions should be
decoded to (included in the m2g subgraph), i.e. which positions should be output. It should have the same length as the number of
grid position coordinates given in `coords`. The mask being set to True means that corresponding
grid nodes should be included in g2m. If `decode_mask=None` (default), all grid nodes are included.
"""
graph_components: dict[networkx.DiGraph] = {}

Expand Down Expand Up @@ -149,9 +158,19 @@ def create_all_graph_components(
)
graph_components["g2m"] = G_g2m

if decode_mask is None:
# decode to all grid nodes
decode_grid = G_grid
else:
# Select subset of grid nodes to decode to, where m2g should connect
filter_nodes = [
n for n, include in zip(G_grid.nodes, decode_mask, strict=True) if include
]
decode_grid = G_grid.subgraph(filter_nodes)

G_m2g = connect_nodes_across_graphs(
G_source=grid_connect_graph,
G_target=G_grid,
G_target=decode_grid,
method=m2g_connectivity,
**m2g_connectivity_kwargs,
)
Expand All @@ -162,6 +181,14 @@ def create_all_graph_components(
for edge in graph.edges:
graph.edges[edge]["component"] = name

if return_components:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a comment here, something like:

Suggested change
if return_components:
if return_components:
# Because merging to a single graph and then splitting again leads to changes in node indexing when converting to `pyg.Data` objects (this in part is due to the to `m2g` and `g2m` having a different set of grid nodes) the ability to return the graph components (`g2m`, `m2m` and `m2g`) has been added here. See https://github.com/mllam/weather-model-graphs/pull/34#issuecomment-2507980752 for details

# Give each component unique ids
graph_components = {
comp_name: replace_node_labels_with_unique_ids(subgraph)
for comp_name, subgraph in graph_components.items()
}
return graph_components

# merge to single graph
G_tot = networkx.compose_all(graph_components.values())
# only keep graph attributes that are the same for all components
Expand Down
49 changes: 19 additions & 30 deletions src/weather_model_graphs/networkx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,40 +98,29 @@ def split_graph_by_edge_attribute(graph, attr):
f"No subgraphs were created. Check the edge attribute '{attr}'."
)

# copy node attributes
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow, what a mess 😮 I don't know what I did here... Good you are removing these extra lines

for subgraph in subgraphs.values():
for node in subgraph.nodes:
subgraph.nodes[node].update(graph.nodes[node])

# check that at least one subgraph was created
if len(subgraphs) == 0:
raise ValueError(
f"No subgraphs were created. Check the edge attribute '{attr}'."
)
return subgraphs

# copy node attributes
for subgraph in subgraphs.values():
for node in subgraph.nodes:
subgraph.nodes[node].update(graph.nodes[node])

# check that at least one subgraph was created
if len(subgraphs) == 0:
raise ValueError(
f"No subgraphs were created. Check the edge attribute '{attr}'."
)
def sort_nodes_in_graph(graph):
"""
Creates a new networkx.DiGraph that is a copy of input, but with nodes
sorted according to their id
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sorted according to their id
sorted according to their label value

I think strictly speaking nodes are identified by something called their "label" rather than "id" (e.g. https://networkx.org/documentation/stable/reference/generated/networkx.relabel.convert_node_labels_to_integers.html#networkx.relabel.convert_node_labels_to_integers). And in our case we use integers as labels


# copy node attributes
for subgraph in subgraphs.values():
for node in subgraph.nodes:
subgraph.nodes[node].update(graph.nodes[node])
Parameters
----------
graph : networkx.DiGraph
Graph to sort nodes from

# check that at least one subgraph was created
if len(subgraphs) == 0:
raise ValueError(
f"No subgraphs were created. Check the edge attribute '{attr}'."
)
Returns
-------
networkx.DiGraph
Graph with sorted nodes
"""
sorted_graph = networkx.DiGraph()
sorted_graph.add_nodes_from(sorted(graph.nodes(data=True)))
sorted_graph.add_edges_from(graph.edges(data=True))

return subgraphs
return sorted_graph


def replace_node_labels_with_unique_ids(graph):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am realising now that this is a poor name, maybe I should have called it replace_node_labels_with_unique_integer_values

Expand All @@ -149,7 +138,7 @@ def replace_node_labels_with_unique_ids(graph):
Graph with node labels renamed
"""
return networkx.relabel_nodes(
graph, {node: i for i, node in enumerate(graph.nodes)}, copy=True
graph, {node: i for i, node in enumerate(sorted(graph.nodes))}, copy=True
)


Expand Down
27 changes: 19 additions & 8 deletions src/weather_model_graphs/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import networkx
from loguru import logger

from .networkx_utils import MissingEdgeAttributeError, split_graph_by_edge_attribute
from .networkx_utils import (
MissingEdgeAttributeError,
sort_nodes_in_graph,
split_graph_by_edge_attribute,
)

try:
import torch
Expand Down Expand Up @@ -106,18 +110,25 @@ def _concat_pyg_features(
if list_from_attribute is not None:
# create a list of graph objects by splitting the graph by the list_from_attribute
try:
sub_graphs = list(
split_graph_by_edge_attribute(
graph=graph, attr=list_from_attribute
).values()
)
sub_graphs = [
value
for key, value in sorted(
split_graph_by_edge_attribute(
graph=graph, attr=list_from_attribute
).items()
)
]
except MissingEdgeAttributeError:
# neural-lam still expects a list of graphs, so if the attribute is missing
# we just return the original graph as a list
sub_graphs = [graph]
pyg_graphs = [pyg_convert.from_networkx(g) for g in sub_graphs]
# Nodes must be sorted if we want to preserve any ordering
# when converted to pyg
pyg_graphs = [
pyg_convert.from_networkx(sort_nodes_in_graph(g)) for g in sub_graphs
]
else:
pyg_graphs = [pyg_convert.from_networkx(graph)]
pyg_graphs = [pyg_convert.from_networkx(sort_nodes_in_graph(graph))]

edge_features_values = [
_concat_pyg_features(pyg_g, features=edge_features) for pyg_g in pyg_graphs
Expand Down
3 changes: 2 additions & 1 deletion src/weather_model_graphs/visualise/plot_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def nx_draw_with_pos_and_attr(
node_zorder_attr=None,
node_size=100,
connectionstyle="arc3, rad=0.1",
with_labels=False,
**kwargs,
):
"""
Expand Down Expand Up @@ -171,7 +172,7 @@ def nx_draw_with_pos_and_attr(
graph,
ax=ax,
arrows=True,
with_labels=False,
with_labels=with_labels,
node_size=node_size,
connectionstyle=connectionstyle,
**kwargs,
Expand Down
24 changes: 24 additions & 0 deletions tests/test_graph_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,27 @@ def test_create_lat_lon(kind):
coords_crs=coords_crs,
graph_crs=graph_crs,
)


@pytest.mark.parametrize("kind", ["graphcast", "keisler", "oskarsson_hierarchical"])
def test_create_decode_mask(kind):
"""
Tests that the decode mask for m2g works, resulting in less edges than
no filtering.
"""
xy = test_utils.create_fake_irregular_coords(100)
fn_name = f"create_{kind}_graph"
fn = getattr(wmg.create.archetype, fn_name)
# ~= 20 mesh nodes in bottom layer in each direction
mesh_node_distance = 0.05

unfiltered_graph = fn(coords=xy, mesh_node_distance=mesh_node_distance)

# Filter to only 20 / 100 grid nodes
decode_mask = np.concatenate((np.ones(20), np.zeros(80))).astype(bool)
filtered_graph = fn(
coords=xy, mesh_node_distance=mesh_node_distance, decode_mask=decode_mask
)

# Check that some filtering has been performed
assert len(filtered_graph.edges) < len(unfiltered_graph.edges)
Loading
Loading