Skip to content

Commit

Permalink
Merge pull request #234 from maxhully/master
Browse files Browse the repository at this point in the history
Bring `Graph` in line with Getting started guide; Add test for `from_file()`
  • Loading branch information
maxhully authored Oct 18, 2018
2 parents 79b8781 + 1fa8451 commit 875a5d6
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 19 deletions.
12 changes: 6 additions & 6 deletions docs/user/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ Partitioning the graph with an initial districting plan

Now that we have a graph, we can partition it into districts. Our shapefile has a data
column called ``"DISTRICT"`` assigning a district ID to each node in our adjacency graph.
We can use this assignment to instantiate a :class:`~gerrychain.Partition` object::
We can use this assignment to instantiate a :class:`~gerrychain.GeographicPartition` object::

from gerrychain import Partition
from gerrychain import GeographicPartition

initial_partition = Partition(vtds_graph, assignment="DISTRICT")
initial_partition = GeographicPartition(vtds_graph, assignment="DISTRICT")

We set ``assignment="DISTRICT"`` to tell the :class:`~gerrychain.Partition` object to use
We set ``assignment="DISTRICT"`` to tell the :class:`~gerrychain.GeographicPartition` object to use
the ``"DISTRICT"`` node attribute to assign nodes into districts. The ``assignment``
parameter could also have been a dictionary from node ID to district ID. This is useful
when your adjacency graph and districting plan data are coming from two separate sources.
Expand Down Expand Up @@ -114,7 +114,7 @@ we see that the value of the ``perimeter`` attribute is itself a dictionary mapp
the perimeter of the district.

Under the hood, these attributes are computed by "updater" functions. The user can pass their own
``updaters`` dictionary when instantiating a ``Partition``, and the values will be accessible just like
the ``perimeter`` attribute above. For more details, see :mod:`gerrychain.updaters`.
``updaters`` dictionary when instantiating a partition, and the values will be accessible by key using the
same syntax as the ``perimeter`` attribute above. For more details, see :mod:`gerrychain.updaters`.

.. TODO: Elections
14 changes: 9 additions & 5 deletions gerrychain/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def from_file(
df = gp.read_file(filename)
graph = cls.from_geodataframe(df, adjacency, reproject)
graph.add_data(df, columns=cols_to_add)
return graph

@classmethod
def from_geodataframe(cls, dataframe, adjacency=Adjacency.Rook, reproject=True):
Expand Down Expand Up @@ -108,18 +109,21 @@ def add_data(self, df, columns=None):
column_dictionaries = df.to_dict("index")
networkx.set_node_attributes(self, column_dictionaries)

def assignment(self, node_attribute_key):
"""Create an assignment dictionary using an attribute of the nodes
of the graph. For example, if you created your graph from Census data
def node_attribute(self, node_attribute_key):
"""Create a dictionary of the form ``{node: <attribute value>}`` for
the given attribute key, over all nodes of the graph.
This is useful for creating an assignment dictionary from an attribute
from a source data file. For example, if you created your graph from Census data
and each node has a `CD` attribute that gives the congressional district
the node belongs to, then `graph.assignment("CD")` would return the
the node belongs to, then `graph.node_attribute("CD")` would return the
desired assignment of nodes to CDs.
:param graph: NetworkX graph.
:param node_attribute_key: Attribute available on all nodes.
:return: Dictionary of {node_id: attribute} pairs.
"""
return networkx.get_node_attributes(self, node_attribute_key)
return {node: data[node_attribute_key] for node, data in self.nodes.items()}

def join(self, dataframe, columns=None, left_index=None, right_index=None):
"""Add data from a dataframe to the graph, matching nodes to rows when
Expand Down
13 changes: 7 additions & 6 deletions gerrychain/partition/partition.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import collections

from gerrychain.graph import Graph
from gerrychain.updaters import (compute_edge_flows, cut_edges,
flows_from_changes)
from gerrychain.updaters import compute_edge_flows, cut_edges, flows_from_changes


class Partition:
Expand All @@ -12,10 +11,12 @@ class Partition:
aggregations and calculations that we want to optimize.
"""
default_updaters = {'cut_edges': cut_edges}

def __init__(self, graph=None, assignment=None, updaters=None,
parent=None, flips=None):
default_updaters = {"cut_edges": cut_edges}

def __init__(
self, graph=None, assignment=None, updaters=None, parent=None, flips=None
):
"""
:param graph: Underlying graph; a NetworkX object.
:param assignment: Dictionary assigning nodes to districts. If None,
Expand All @@ -36,7 +37,7 @@ def _first_time(self, graph, assignment, updaters):
self.graph = graph

if isinstance(assignment, str):
assignment = graph.assignment(assignment)
assignment = graph.node_attribute(assignment)
elif not isinstance(assignment, dict):
raise TypeError("Assignment must be a dict or a node attribute key")
self.assignment = assignment
Expand Down
32 changes: 30 additions & 2 deletions tests/test_make_graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pathlib
from tempfile import TemporaryDirectory

import geopandas as gp
import pandas
import pytest
Expand All @@ -13,10 +16,17 @@ def geodataframe():
c = Polygon([(1, 0), (1, 1), (2, 1), (2, 0)])
d = Polygon([(1, 1), (1, 2), (2, 2), (2, 1)])
df = gp.GeoDataFrame({"ID": ["a", "b", "c", "d"], "geometry": [a, b, c, d]})
df.crs = "+init=epsg:4326"
df.crs = "+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs"
return df


@pytest.fixture
def gdf_with_data(geodataframe):
geodataframe["data"] = list(range(len(geodataframe)))
geodataframe["data2"] = list(range(len(geodataframe)))
return geodataframe


@pytest.fixture
def geodataframe_with_boundary():
"""
Expand All @@ -30,7 +40,7 @@ def geodataframe_with_boundary():
d = Polygon([(1, 1), (1, 2), (2, 2), (2, 1)])
e = Polygon([(2, 0), (2, 1), (2, 2), (2, 3), (3, 3), (3, 2), (3, 1), (3, 0)])
df = gp.GeoDataFrame({"ID": ["a", "b", "c", "d", "e"], "geometry": [a, b, c, d, e]})
df.crs = "+init=epsg:4326"
df.crs = "+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs"
return df


Expand Down Expand Up @@ -142,3 +152,21 @@ def test_computes_boundary_perims(geodataframe_with_boundary):

def edge_set_equal(set1, set2):
return {(y, x) for x, y in set1} | set1 == {(y, x) for x, y in set2} | set2


def test_from_file_adds_all_data_by_default(gdf_with_data):
with TemporaryDirectory() as d:
filepath = pathlib.Path(d) / "temp.shp"
filename = str(filepath.absolute())
gdf_with_data.to_file(filename)
graph = Graph.from_file(filename)

assert all("data" in node_data for node_data in graph.nodes.values())
assert all("data2" in node_data for node_data in graph.nodes.values())


def test_graph_assignment_raises_if_data_is_missing():
graph = Graph([(1, 2), (2, 3), (3, 1)])

with pytest.raises(KeyError):
graph.node_attribute("missing_data_key")

0 comments on commit 875a5d6

Please sign in to comment.