diff --git a/jaxley/io/graph.py b/jaxley/io/graph.py index bc510209..40bd22d7 100644 --- a/jaxley/io/graph.py +++ b/jaxley/io/graph.py @@ -266,6 +266,23 @@ def split_branches( def insert_compartments(graph: nx.DiGraph, ncomp_per_branch: int) -> nx.DiGraph: + """Insert compartment nodes into the graph. + + Inserts new nodes in every branch (edges with "branch_index" attribute) at equidistant + points along it. Node attributes, like radius are linearly interpolated along its + length. + + Example: 4 compartments | edges = - | nodes = o | comp_nodes = x + o-----------o----------o---o---o---o--------o + o-------x---o----x-----o--xo---o---ox-------o + + Args: + graph: Mmorphology where edges are already labelled with "branch_index" + ncomp_per_branch: How many compartments per branch to insert + + Returns: + Graph with additional nodes that are labelled with "comp_index" + """ comp_offset = 0 branch_inds = nx.get_edge_attributes(graph, "branch_index") @@ -319,7 +336,24 @@ def insert_compartments(graph: nx.DiGraph, ncomp_per_branch: int) -> nx.DiGraph: return graph -def get_comp_edges_dfs(graph, node, parent_comp=None, visited=None): +def get_comp_edges_dfs( + graph: nx.DiGraph, node: int, parent_comp: int = None, visited: set = None +) -> List[Tuple[int]]: + """List edges between compartment nodes, ignoring non-compartment nodes. + + Traverses a graph depth first and only records nodes and their successors if they + have a "compartment_index". + + Args: + graph: Morphology with inserted compartment nodes. + node: node_index from which to start depth first traversal + parent_comp: node_index of parent compartment + visited: Keeps track of visited nodes during traversal + + Returns: + List of edges (parent_node_index, child_node_index) that directly connect + compartments, while skipping indermediate hops via nodes w.o. "comp_index" attr. + """ if visited is None: visited = set() edges = [] @@ -343,7 +377,19 @@ def get_comp_edges_dfs(graph, node, parent_comp=None, visited=None): return edges -def extract_comp_graph(graph): +def extract_comp_graph(graph: nx.DiGraph) -> nx.DiGraph: + """Get subgraph that only includes compartment nodes and their direct edges. + + Example: 4 compartments | edges = - | nodes = o | comp_nodes = x + o-------x---o----x-----o--xo---o---ox-------o + x--------x--------x---------x + + Args: + graph: Morphology with compartment nodes + + Returns: + Morphology with branches and compartments. + """ # create subgraph w.o. edges comp_nodes = list(nx.get_node_attributes(graph, "comp_index")) comp_graph = nx.subgraph(graph, comp_nodes).copy()