Add a decoding mask option to only include subset of grid nodes in m2g #34
base: main
Conversation
Looks good! Agree that this is required. Testing for "less than" seems sufficient 👍
More general comment: I am slightly worried about the mask being generated in mllam-data-prep and then being used here and in neural-lam. We probably need some solid checks in neural-lam about the dimensions of the graph, the data tensors and the mask.
There is one small problem currently, which is one reason why a6b6137 was introduced. When using the decoding mask (maybe this could happen also without it, not sure?), not all mesh nodes might be connected to the grid in m2g, i.e. there are mesh nodes that are not endpoints of any edges in m2g. This is desirable. However, when you split up a graph using the edge-attribute-based sub-graphing, only nodes that are endpoints of some edge will be included; node subsetting is only implicit through the edge sub-graphing (weather-model-graphs/src/weather_model_graphs/networkx_utils.py, lines 80 to 88 in a6d43e3). This means that something like the code at weather-model-graphs/tests/test_save.py, line 30 in a6d43e3, is affected by this.
This is why the option to return separate components was introduced, since it is unnecessary to merge everything and then split it up again (especially as splitting now breaks the node indices). Do you have any thoughts about a good way to fix this @leifdenby? Or do we want to fix it? Not sure if it has to be fixed in this PR, since in a sense it's a more general issue with the edge-attribute-based sub-graphing.
Now all changes from #32 are merged into main and into this branch, so this diff is readable.
Ok, that sounds like an issue. I don't fully understand why yet, but I believe you. So rather than merging to create a graph that represents the whole encode-process-decode graph of nodes and edges, you introduced the option to return the separate components (return_components) instead?
The reason why I implemented the graph creation with a merge into one big graph is so that each node could be given a globally (as in across the whole graph) unique ID. I then assumed that this ID could be used for constructing the adjacency matrices for the different components (g2m, m2m, m2g, say), which are then saved to individual files. Are you saying that this approach of having a global index doesn't work when masking out which grid nodes are used in the decoding (m2g)? I might have to create a little toy example to understand the issue properly.
Maybe if you have time you could add a test/notebook @joeloskarsson with an assert that checks for the indexing that you expect? I assume it is the … I tried creating a notebook just now, but I am not quite sure about what the correct answer would be...
Yes, the issue does appear when you split the combined graph. But the problem is not that the indices are wrong, it is that you lose some nodes (which makes the indices wrong on the remaining nodes). So your toy example could be:
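For instance, a minimal standalone sketch in plain networkx (not the package's own split_graph_by_edge_attribute, and only a reconstruction of the kind of toy example meant here):

```python
import networkx as nx

g = nx.DiGraph()
# grid nodes 0-1, mesh nodes 2-3; mesh node 3 is not an endpoint of any m2g edge
g.add_nodes_from([0, 1, 2, 3])
g.add_edge(0, 2, component="g2m")
g.add_edge(1, 3, component="g2m")
g.add_edge(2, 0, component="m2g")  # only mesh node 2 decodes back to the grid

# edge-attribute-based sub-graphing: keep only the m2g edges
m2g_edges = [e for e, v in nx.get_edge_attributes(g, "component").items() if v == "m2g"]
m2g = g.edge_subgraph(m2g_edges).copy()

print(sorted(m2g.nodes))  # [0, 2] -- nodes 1 and 3 are lost
# relabelling to consecutive integers (as a conversion to pyg.Data effectively does)
# now gives mesh node 2 the index 1, so the "global" node indices are broken
print(sorted(nx.convert_node_labels_to_integers(m2g).nodes))  # [0, 1]
```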
I have dug a bit deeper @joeloskarsson and I have realised where the issue is. I was under the impression that … What I still don't quite understand is what neural-lam assumes about the grid-index values for different parts of the whole encode-process-decode graph. I think (based on the fact that we're having this discussion) that … I have added a notebook with my experimentation here: https://github.com/leifdenby/weather-model-graphs/blob/decoding_mask/docs/decoding_mask.ipynb, and a start on a test for the decoding mask too: https://github.com/leifdenby/weather-model-graphs/blob/decoding_mask/tests/test_graph_decode_gridpoints_mask.py#L17
Yes, this is exactly what I also realized. I wish I had had a bit more time to write out an explanation for this, then I could have saved you the work 😓 Well well, I hope you feel like you gained some insights on the way. In my view the arbitrariness of indexing when converting networkx to pyg feels really bad and not very thought through. But I guess this is all just because networkx doesn't really have a notion of an integer node index (closer to how we think of a graph in theory, with sets; bad for practical work imo 😛).
I am not sure if that is actually necessary, but we could do that. My implementation using this in neural-lam makes the hard decision that the nodes that you want to decode to get the first indices, followed by the nodes only used in g2m. This means that it doesn't matter if you include the masked-out node indices in your m2g graph (you'll get the same edge_index tensor in the end anyway). On a more conceptual level, one could either argue that m2g should only index into the set of nodes we decode to, or argue that m2g should index into the set of all grid nodes but only contain edges to the unmasked ones. It is not obvious to me which makes more sense.
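As an illustrative sketch of that index convention (not neural-lam's actual code): the decoded-to grid nodes come first, followed by the grid nodes that only appear in g2m.

```python
import numpy as np

decode_mask = np.array([True, False, True, True, False])  # one entry per grid node

# decoded-to nodes first, then the grid nodes only used in g2m
order = np.concatenate([np.where(decode_mask)[0], np.where(~decode_mask)[0]])

# new index of each original grid node under this ordering
new_index = np.empty_like(order)
new_index[order] = np.arange(order.size)

print(order)      # [0 2 3 1 4]
print(new_index)  # [0 3 1 2 4]
```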
Yea, there is quite a lot to this. I have made very minimal assumptions wrt node indexing in neural-lam. Given that I know that we have some different perspectives on the graph (I have this perspective of separate graphs, rather than one joined graph), I feel like it is something that should be explained better. Maybe I could try to write up some explanation of what neural-lam expects (that is, in mllam/neural-lam#93) when I find some time, or we could sit down and talk it through.
Thanks @joeloskarsson! I am glad we finally found out where the indexing issue is arising from 😅 @SimonKamuk had the idea that if we want to retain the idea of building a single networkx.DiGraph object and then splitting it, what we could do is to add an option when splitting this "global" graph that indicates that the user wishes to retain all the nodes in every subgraph, but only the edges that have a specific value for an attribute. This would mean that all subgraphs that are created would share the same indexing when converting to pyg, but of course each subgraph will also contain a lot of nodes that are not connected by any edges. Would this be ok? I made the plot I've included here when I was trying to work out what you store in the pickled tensor files, and at least wrt the edge index that appears to be a globally unique integer for each node. So maybe this would be an ok approach? One thing I'm worried about is that this would cause the … I started work on the node-retention idea in https://github.com/leifdenby/weather-model-graphs/blob/feat/node-retention-with-graph-splitting/src/weather_model_graphs/networkx_utils.py#L51, but then I realised that it would be better to check with you.
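Something along these lines, as a rough standalone sketch (not the actual code in the linked branch):

```python
import networkx as nx


def split_by_edge_attribute_keep_all_nodes(graph, attr):
    """Split on an edge attribute, but retain every node in every subgraph so
    that node indexing stays identical across the resulting components."""
    values = {d[attr] for _, _, d in graph.edges(data=True)}
    subgraphs = {}
    for value in values:
        sg = nx.DiGraph()
        sg.add_nodes_from(graph.nodes(data=True))  # all nodes, including unconnected ones
        sg.add_edges_from(
            (u, v, d) for u, v, d in graph.edges(data=True) if d[attr] == value
        )
        subgraphs[value] = sg
    return subgraphs
```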
I think what you describe would indeed work, and solve the problem. But I would have to test with neural-lam to be sure. And yes, we would have to then rework what is saved in each of the graph files. I think however that the easier fix is to rework split_graph_by_edge_attribute to somehow not drop the mesh nodes that are not connected to the grid. Or rather, to split on something else than edge attributes then. One option could be to introduce some sort of node attribute that one could split based on, telling which sub-graphs each node is part of. That gets slightly complicated though, as nodes can belong to multiple subgraphs (e.g. g2m and m2g). Then there is of course always the option to not merge all graphs just to then split them again ;) (as I implemented here). Regarding this PR though: are you ok with merging this with the return_components option included?
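As a purely illustrative sketch of that node-attribute idea (nothing like this exists in the package; the node names and the "components" attribute are made up): each node records which components it belongs to, and splitting selects nodes by that membership, so mesh nodes without any m2g edges are kept.

```python
import networkx as nx


def split_by_node_membership(graph, component):
    """Subgraph on a node attribute listing component membership, then keep
    only the edges tagged with that component."""
    nodes = [n for n, d in graph.nodes(data=True) if component in d["components"]]
    sg = graph.subgraph(nodes).copy()
    sg.remove_edges_from(
        [(u, v) for u, v, d in sg.edges(data=True) if d["component"] != component]
    )
    return sg


g = nx.DiGraph()
g.add_node("grid0", components={"g2m", "m2g"})
g.add_node("mesh0", components={"g2m", "m2m", "m2g"})
g.add_node("mesh1", components={"g2m", "m2m", "m2g"})  # no m2g edges, but still tagged
g.add_edge("grid0", "mesh0", component="g2m")
g.add_edge("grid0", "mesh1", component="g2m")
g.add_edge("mesh0", "grid0", component="m2g")

m2g = split_by_node_membership(g, "m2g")
print(sorted(m2g.nodes))  # ['grid0', 'mesh0', 'mesh1'] -- mesh1 is not dropped
```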
Ok, great!
Good point. I think we should merge this PR with the functionality you have proposed. The functionality added here makes sense and we should keep moving. I will give a final review of the code as-is so we can get this in. (In general wrt how the graph is stored to disk from …)
The only thing that I think is missing here is a documentation notebook showing how to use this. I have not got around to making that yet. But maybe you could just turn the notebook you added into that?
I think the notebook you have added is already fine. Or by "this" do you mean something other than the decoding mask functionality?
Looks great! I've just added a few suggestions
projection: cartopy.crs.CRS or None
    Projection instance used to transform given lat-lon coords to in-projection
    Cartesian coordinates. If None the coords are assumed to already be Cartesian.
decode_mask: Iterable or None
I think it would be good to make decode_mask an explicit argument rather than having it in **kwargs (I think that's where it is currently supposed to be handled, right?). Also, in general, maybe we could use Iterable[int] in type hints to make it clear that the mask is represented by a set of ints.
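A tiny sketch of what that could look like (the function name here is a placeholder, not anything in the package):

```python
from collections.abc import Iterable


def create_decode_graph(  # hypothetical name, for illustration only
    *args,
    decode_mask: Iterable[int] | None = None,  # explicit, not hidden in **kwargs
    **kwargs,
):
    ...
```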
I think it would be good to keep the explicit arguments here. This would mean that this file is mostly unchanged, although there was a missing docstring for max_num_levels in create_oskarsson_hierarchical_graph() that you have added in this PR.
@@ -39,6 +41,8 @@ def create_all_graph_components(
    g2m_connectivity_kwargs={},
    coords_crs: pyproj.crs.CRS | None = None,
    graph_crs: pyproj.crs.CRS | None = None,
    decode_mask: Iterable | None = None,
I would move this decode_mask argument into the m2g_connectivity arguments, I think, and then masking could in fact be done within connect_nodes_across_graphs(), since the masking affects the way nodes are connected between graphs. What do you think?
You could for example introduce arguments called source_nodes_mask and target_nodes_mask to connect_nodes_across_graphs (https://github.com/mllam/weather-model-graphs/blob/main/src/weather_model_graphs/create/base.py#L181).
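For example, something along these lines (a rough sketch; the helper name and the boolean-mask semantics are assumptions, not the actual connect_nodes_across_graphs implementation):

```python
import numpy as np


def _apply_node_masks(source_nodes, target_nodes, source_nodes_mask=None, target_nodes_mask=None):
    """Filter the candidate source/target nodes before creating cross-graph edges."""
    source_nodes = np.asarray(source_nodes)
    target_nodes = np.asarray(target_nodes)
    if source_nodes_mask is not None:
        source_nodes = source_nodes[np.asarray(source_nodes_mask, dtype=bool)]
    if target_nodes_mask is not None:
        target_nodes = target_nodes[np.asarray(target_nodes_mask, dtype=bool)]
    return source_nodes, target_nodes


# e.g. for m2g: keep all mesh (source) nodes, but only the grid (target) nodes
# selected by decode_mask:
# src, tgt = _apply_node_masks(mesh_node_ids, grid_node_ids, target_nodes_mask=decode_mask)
```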
@@ -162,6 +181,14 @@ def create_all_graph_components(
    for edge in graph.edges:
        graph.edges[edge]["component"] = name

    if return_components:
Could we add a comment here, something like:

if return_components:
    # Because merging to a single graph and then splitting again leads to changes in
    # node indexing when converting to `pyg.Data` objects (in part due to `m2g` and
    # `g2m` having a different set of grid nodes), the ability to return the graph
    # components (`g2m`, `m2m` and `m2g`) has been added here. See
    # https://github.com/mllam/weather-model-graphs/pull/34#issuecomment-2507980752 for details
@@ -98,40 +98,29 @@ def split_graph_by_edge_attribute(graph, attr):
            f"No subgraphs were created. Check the edge attribute '{attr}'."
        )

    # copy node attributes
wow, what a mess 😮 I don't know what I did here... Good you are removing these extra lines
def sort_nodes_in_graph(graph):
    """
    Creates a new networkx.DiGraph that is a copy of input, but with nodes
    sorted according to their id
Suggested change:
-    sorted according to their id
+    sorted according to their label value

I think strictly speaking nodes are identified by something called their "label" rather than "id" (e.g. https://networkx.org/documentation/stable/reference/generated/networkx.relabel.convert_node_labels_to_integers.html#networkx.relabel.convert_node_labels_to_integers). And in our case we use integers as labels.
    return subgraphs

    return sorted_graph


def replace_node_labels_with_unique_ids(graph):
I am realising now that this is a poor name, maybe I should have called it replace_node_labels_with_unique_integer_values
Describe your changes
For LAM forecasting it is reasonable to not always predict values in nodes that are considered part of the boundary forcing. This is in particular a change planned for neural-lam. When we consider the graphs involved, this means that the m2g edges should only connect to a subset of the grid nodes.
This PR introduces an option decode_mask that allows specifying an Iterable of booleans (e.g. a numpy array) indicating which of the grid nodes should be included in the decoding part of the graph (m2g). In the LAM case this allows specifying such a mask with True for the inner-region nodes.

This builds on #32, which should be merged first. Here is a diff for only this PR in the meantime: joeloskarsson/weather-model-graphs@general_coordinates...decoding_mask
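As a rough illustration of constructing such a mask (the grid shape, boundary width and row-major node ordering below are example assumptions, not part of this PR):

```python
import numpy as np

ny, nx = 120, 100        # full LAM grid, interior + boundary (example values)
boundary_width = 10      # width of the boundary-forcing band (example value)

interior = np.zeros((ny, nx), dtype=bool)
interior[boundary_width:-boundary_width, boundary_width:-boundary_width] = True

# flattened boolean mask, one entry per grid node, passed as decode_mask=...
decode_mask = interior.ravel()
print(f"{decode_mask.sum()} of {decode_mask.size} grid nodes included in m2g")
```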
Type of change
Checklist before requesting a review
pull with --rebase option if possible).
Checklist for reviewers
Each PR comes with its own improvements and flaws. The reviewer should check the following:
Author checklist after completed review
reflecting type of change (add section where missing):
Checklist for assignee