-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a decoding mask option to only include subset of grid nodes in m2g #34
base: main
Are you sure you want to change the base?
Changes from all commits
8a11527
2c7e4eb
6f03908
d2a081a
9dbd2a5
eada6ea
01180a1
d103ad6
7ba387e
9069c84
d592c2b
fd695b1
2c781ee
5cf3bbf
0f17d2c
f145025
cc4cc5e
2627e37
c764fd7
f6ae35b
22caf65
7ec34ff
f564c70
2333d53
3ee25ae
70eef3e
3c4866b
5f33bc5
63df482
350e0c0
806b78f
4fb69dc
8fcf182
746966f
47b5dcc
05a0cc6
339feff
6b281bf
a525818
8733359
c95f023
bcaf1e1
8e3c1cb
541054e
169ea48
a6b6137
7576b43
ae90eb8
ee8601e
e178a0b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -9,6 +9,8 @@ | |||||||
""" | ||||||||
|
||||||||
|
||||||||
from typing import Iterable | ||||||||
|
||||||||
import networkx | ||||||||
import networkx as nx | ||||||||
import numpy as np | ||||||||
|
@@ -39,6 +41,8 @@ def create_all_graph_components( | |||||||
g2m_connectivity_kwargs={}, | ||||||||
coords_crs: pyproj.crs.CRS | None = None, | ||||||||
graph_crs: pyproj.crs.CRS | None = None, | ||||||||
decode_mask: Iterable | None = None, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would move this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You could for example introduce arguments called |
||||||||
return_components: bool = False, | ||||||||
): | ||||||||
""" | ||||||||
Create all graph components used in creating the message-passing graph, | ||||||||
|
@@ -82,6 +86,11 @@ def create_all_graph_components( | |||||||
will be transformed from their original Coordinate Reference System (`coords_crs`) | ||||||||
to the CRS where the graph creation should take place (`graph_crs`). | ||||||||
If any one of them is None the graph creation is carried out using the original coords. | ||||||||
|
||||||||
`decode_mask` should be an Iterable of booleans, masking which grid positions should be | ||||||||
decoded to (included in the m2g subgraph), i.e. which positions should be output. It should have the same length as the number of | ||||||||
grid position coordinates given in `coords`. The mask being set to True means that corresponding | ||||||||
grid nodes should be included in g2m. If `decode_mask=None` (default), all grid nodes are included. | ||||||||
""" | ||||||||
graph_components: dict[networkx.DiGraph] = {} | ||||||||
|
||||||||
|
@@ -149,9 +158,19 @@ def create_all_graph_components( | |||||||
) | ||||||||
graph_components["g2m"] = G_g2m | ||||||||
|
||||||||
if decode_mask is None: | ||||||||
# decode to all grid nodes | ||||||||
decode_grid = G_grid | ||||||||
else: | ||||||||
# Select subset of grid nodes to decode to, where m2g should connect | ||||||||
filter_nodes = [ | ||||||||
n for n, include in zip(G_grid.nodes, decode_mask, strict=True) if include | ||||||||
] | ||||||||
decode_grid = G_grid.subgraph(filter_nodes) | ||||||||
|
||||||||
G_m2g = connect_nodes_across_graphs( | ||||||||
G_source=grid_connect_graph, | ||||||||
G_target=G_grid, | ||||||||
G_target=decode_grid, | ||||||||
method=m2g_connectivity, | ||||||||
**m2g_connectivity_kwargs, | ||||||||
) | ||||||||
|
@@ -162,6 +181,14 @@ def create_all_graph_components( | |||||||
for edge in graph.edges: | ||||||||
graph.edges[edge]["component"] = name | ||||||||
|
||||||||
if return_components: | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we add a comment here, something like:
Suggested change
|
||||||||
# Give each component unique ids | ||||||||
graph_components = { | ||||||||
comp_name: replace_node_labels_with_unique_ids(subgraph) | ||||||||
for comp_name, subgraph in graph_components.items() | ||||||||
} | ||||||||
return graph_components | ||||||||
|
||||||||
# merge to single graph | ||||||||
G_tot = networkx.compose_all(graph_components.values()) | ||||||||
# only keep graph attributes that are the same for all components | ||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -98,40 +98,29 @@ def split_graph_by_edge_attribute(graph, attr): | |||||
f"No subgraphs were created. Check the edge attribute '{attr}'." | ||||||
) | ||||||
|
||||||
# copy node attributes | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wow, what a mess 😮 I don't know what I did here... Good you are removing these extra lines |
||||||
for subgraph in subgraphs.values(): | ||||||
for node in subgraph.nodes: | ||||||
subgraph.nodes[node].update(graph.nodes[node]) | ||||||
|
||||||
# check that at least one subgraph was created | ||||||
if len(subgraphs) == 0: | ||||||
raise ValueError( | ||||||
f"No subgraphs were created. Check the edge attribute '{attr}'." | ||||||
) | ||||||
return subgraphs | ||||||
|
||||||
# copy node attributes | ||||||
for subgraph in subgraphs.values(): | ||||||
for node in subgraph.nodes: | ||||||
subgraph.nodes[node].update(graph.nodes[node]) | ||||||
|
||||||
# check that at least one subgraph was created | ||||||
if len(subgraphs) == 0: | ||||||
raise ValueError( | ||||||
f"No subgraphs were created. Check the edge attribute '{attr}'." | ||||||
) | ||||||
def sort_nodes_in_graph(graph): | ||||||
""" | ||||||
Creates a new networkx.DiGraph that is a copy of input, but with nodes | ||||||
sorted according to their id | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I think strictly speaking nodes are identified by something called their "label" rather than "id" (e.g. https://networkx.org/documentation/stable/reference/generated/networkx.relabel.convert_node_labels_to_integers.html#networkx.relabel.convert_node_labels_to_integers). And in our case we use integers as labels |
||||||
|
||||||
# copy node attributes | ||||||
for subgraph in subgraphs.values(): | ||||||
for node in subgraph.nodes: | ||||||
subgraph.nodes[node].update(graph.nodes[node]) | ||||||
Parameters | ||||||
---------- | ||||||
graph : networkx.DiGraph | ||||||
Graph to sort nodes from | ||||||
|
||||||
# check that at least one subgraph was created | ||||||
if len(subgraphs) == 0: | ||||||
raise ValueError( | ||||||
f"No subgraphs were created. Check the edge attribute '{attr}'." | ||||||
) | ||||||
Returns | ||||||
------- | ||||||
networkx.DiGraph | ||||||
Graph with sorted nodes | ||||||
""" | ||||||
sorted_graph = networkx.DiGraph() | ||||||
sorted_graph.add_nodes_from(sorted(graph.nodes(data=True))) | ||||||
sorted_graph.add_edges_from(graph.edges(data=True)) | ||||||
|
||||||
return subgraphs | ||||||
return sorted_graph | ||||||
|
||||||
|
||||||
def replace_node_labels_with_unique_ids(graph): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am realising now that this is a poor name, maybe I should have called it |
||||||
|
@@ -149,7 +138,7 @@ def replace_node_labels_with_unique_ids(graph): | |||||
Graph with node labels renamed | ||||||
""" | ||||||
return networkx.relabel_nodes( | ||||||
graph, {node: i for i, node in enumerate(graph.nodes)}, copy=True | ||||||
graph, {node: i for i, node in enumerate(sorted(graph.nodes))}, copy=True | ||||||
) | ||||||
|
||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be good to keep the explicit arguments here. This would mean that this file is mostly unchanged, although there was a missing docstring for
max_num_levels
, increate_oskarsson_hierarchical_graph()
that you have added in this PR