diff --git a/docs/user/quickstart.rst b/docs/user/quickstart.rst index 4640ed5f..58cfad26 100644 --- a/docs/user/quickstart.rst +++ b/docs/user/quickstart.rst @@ -65,13 +65,13 @@ Partitioning the graph with an initial districting plan Now that we have a graph, we can partition it into districts. Our shapefile has a data column called ``"DISTRICT"`` assigning a district ID to each node in our adjacency graph. -We can use this assignment to instantiate a :class:`~gerrychain.Partition` object:: +We can use this assignment to instantiate a :class:`~gerrychain.GeographicPartition` object:: - from gerrychain import Partition + from gerrychain import GeographicPartition - initial_partition = Partition(vtds_graph, assignment="DISTRICT") + initial_partition = GeographicPartition(vtds_graph, assignment="DISTRICT") -We set ``assignment="DISTRICT"`` to tell the :class:`~gerrychain.Partition` object to use +We set ``assignment="DISTRICT"`` to tell the :class:`~gerrychain.GeographicPartition` object to use the ``"DISTRICT"`` node attribute to assign nodes into districts. The ``assignment`` parameter could also have been a dictionary from node ID to district ID. This is useful when your adjacency graph and districting plan data are coming from two separate sources. @@ -114,7 +114,7 @@ we see that the value of the ``perimeter`` attribute is itself a dictionary mapp the perimeter of the district. Under the hood, these attributes are computed by "updater" functions. The user can pass their own -``updaters`` dictionary when instantiating a ``Partition``, and the values will be accessible just like -the ``perimeter`` attribute above. For more details, see :mod:`gerrychain.updaters`. +``updaters`` dictionary when instantiating a partition, and the values will be accessible by key using the +same syntax as the ``perimeter`` attribute above. For more details, see :mod:`gerrychain.updaters`. .. 
TODO: Elections \ No newline at end of file diff --git a/gerrychain/graph/graph.py b/gerrychain/graph/graph.py index 2030651c..ed0a1540 100644 --- a/gerrychain/graph/graph.py +++ b/gerrychain/graph/graph.py @@ -43,6 +43,7 @@ def from_file( df = gp.read_file(filename) graph = cls.from_geodataframe(df, adjacency, reproject) graph.add_data(df, columns=cols_to_add) + return graph @classmethod def from_geodataframe(cls, dataframe, adjacency=Adjacency.Rook, reproject=True): @@ -108,18 +109,21 @@ def add_data(self, df, columns=None): column_dictionaries = df.to_dict("index") networkx.set_node_attributes(self, column_dictionaries) - def assignment(self, node_attribute_key): - """Create an assignment dictionary using an attribute of the nodes - of the graph. For example, if you created your graph from Census data + def node_attribute(self, node_attribute_key): + """Create a dictionary of the form ``{node: attribute}`` for + the given attribute key, over all nodes of the graph. + + This is useful for creating an assignment dictionary from an attribute + from a source data file. For example, if you created your graph from Census data and each node has a `CD` attribute that gives the congressional district - the node belongs to, then `graph.assignment("CD")` would return the + the node belongs to, then `graph.node_attribute("CD")` would return the desired assignment of nodes to CDs. :param graph: NetworkX graph. :param node_attribute_key: Attribute available on all nodes. :return: Dictionary of {node_id: attribute} pairs. 
""" - return networkx.get_node_attributes(self, node_attribute_key) + return {node: data[node_attribute_key] for node, data in self.nodes.items()} def join(self, dataframe, columns=None, left_index=None, right_index=None): """Add data from a dataframe to the graph, matching nodes to rows when diff --git a/gerrychain/partition/partition.py b/gerrychain/partition/partition.py index 73b9ebb6..735feac9 100644 --- a/gerrychain/partition/partition.py +++ b/gerrychain/partition/partition.py @@ -1,8 +1,7 @@ import collections from gerrychain.graph import Graph -from gerrychain.updaters import (compute_edge_flows, cut_edges, - flows_from_changes) +from gerrychain.updaters import compute_edge_flows, cut_edges, flows_from_changes class Partition: @@ -12,10 +11,12 @@ class Partition: aggregations and calculations that we want to optimize. """ - default_updaters = {'cut_edges': cut_edges} - def __init__(self, graph=None, assignment=None, updaters=None, - parent=None, flips=None): + default_updaters = {"cut_edges": cut_edges} + + def __init__( + self, graph=None, assignment=None, updaters=None, parent=None, flips=None + ): """ :param graph: Underlying graph; a NetworkX object. :param assignment: Dictionary assigning nodes to districts. 
If None, @@ -36,7 +37,7 @@ def _first_time(self, graph, assignment, updaters): self.graph = graph if isinstance(assignment, str): - assignment = graph.assignment(assignment) + assignment = graph.node_attribute(assignment) elif not isinstance(assignment, dict): raise TypeError("Assignment must be a dict or a node attribute key") self.assignment = assignment diff --git a/tests/test_make_graph.py b/tests/test_make_graph.py index abef6f19..390290d4 100644 --- a/tests/test_make_graph.py +++ b/tests/test_make_graph.py @@ -1,3 +1,6 @@ +import pathlib +from tempfile import TemporaryDirectory + import geopandas as gp import pandas import pytest @@ -13,10 +16,17 @@ def geodataframe(): c = Polygon([(1, 0), (1, 1), (2, 1), (2, 0)]) d = Polygon([(1, 1), (1, 2), (2, 2), (2, 1)]) df = gp.GeoDataFrame({"ID": ["a", "b", "c", "d"], "geometry": [a, b, c, d]}) - df.crs = "+init=epsg:4326" + df.crs = "+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs" return df +@pytest.fixture +def gdf_with_data(geodataframe): + geodataframe["data"] = list(range(len(geodataframe))) + geodataframe["data2"] = list(range(len(geodataframe))) + return geodataframe + + @pytest.fixture def geodataframe_with_boundary(): """ @@ -30,7 +40,7 @@ def geodataframe_with_boundary(): d = Polygon([(1, 1), (1, 2), (2, 2), (2, 1)]) e = Polygon([(2, 0), (2, 1), (2, 2), (2, 3), (3, 3), (3, 2), (3, 1), (3, 0)]) df = gp.GeoDataFrame({"ID": ["a", "b", "c", "d", "e"], "geometry": [a, b, c, d, e]}) - df.crs = "+init=epsg:4326" + df.crs = "+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs" return df @@ -142,3 +152,21 @@ def test_computes_boundary_perims(geodataframe_with_boundary): def edge_set_equal(set1, set2): return {(y, x) for x, y in set1} | set1 == {(y, x) for x, y in set2} | set2 + + +def test_from_file_adds_all_data_by_default(gdf_with_data): + with TemporaryDirectory() as d: + filepath = pathlib.Path(d) / "temp.shp" + filename = str(filepath.absolute()) + gdf_with_data.to_file(filename) + graph = 
Graph.from_file(filename) + + assert all("data" in node_data for node_data in graph.nodes.values()) + assert all("data2" in node_data for node_data in graph.nodes.values()) + + +def test_graph_assignment_raises_if_data_is_missing(): + graph = Graph([(1, 2), (2, 3), (3, 1)]) + + with pytest.raises(KeyError): + graph.node_attribute("missing_data_key")