Skip to content

Commit

Permalink
wip: all but one test passing, working on import export cycle validation
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Jan 14, 2025
1 parent b0d9f58 commit c0607bb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
13 changes: 6 additions & 7 deletions jaxley/io/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,16 @@ def trace_branches(
# handles special case of a single soma node
if len(soma_idxs := get_soma_idxs(graph)) == 1:
soma = soma_idxs[0]
# edges connecting nodes to soma are considered part of the soma -> l = 0.
for i, j in (*graph.in_edges(soma), *graph.out_edges(soma)):
graph.edges[i, j]["l"] = 0

# Setting l = 2*r ensures A_cylinder = 2*pi*r*l = 4*pi*r^2 = A_sphere
graph.add_node(-1, **graph.nodes[0])
graph.add_edge(-1, soma, l=2 * graph.nodes[soma]["r"])
graph = nx.relabel_nodes(graph, {i: i + 1 for i in graph.nodes})

# # edges connecting nodes to soma are considered part of the soma -> l = 0.
# for i, j in (*graph.in_edges(soma), *graph.out_edges(soma)):
# graph.edges[i, j]["l"] = 0

# ensure linear root segment to ensure root branch can be created.
# Ensure root segment is linear. Needed to create root branch.
if graph.out_degree(0) > 1:
graph.add_node(-1, **graph.nodes[0])
graph.add_edge(-1, 0, l=0.1)
Expand Down Expand Up @@ -833,8 +833,6 @@ def to_graph(module: jx.Module) -> nx.DiGraph:
columns=["group", "index"],
)
nodes = pd.concat([nodes, group_inds.groupby("index")["group"].agg(list)], axis=1)

module_graph = nx.DiGraph()
module_graph.add_nodes_from(nodes.T.to_dict().items())

inter_branch_edges = module.branch_edges.copy()
Expand Down Expand Up @@ -864,4 +862,5 @@ def to_graph(module: jx.Module) -> nx.DiGraph:
)

module_graph.graph["type"] = module.__class__.__name__.lower()

return module_graph
9 changes: 7 additions & 2 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,15 @@


# test exporting and re-importing of different modules
def test_graph_import_export_cycle(SimpleComp, SimpleBranch, SimpleCell, SimpleNetwork):
def test_graph_import_export_cycle(
SimpleComp, SimpleBranch, SimpleCell, SimpleNetwork, SimpleMorphCell
):
# build a network
np.random.seed(0)
comp = SimpleComp()
branch = SimpleBranch(4)
cell = SimpleCell(5, 4)
morph_cell = SimpleMorphCell()
net = SimpleNetwork(3, 5, 4)

# add synapses
Expand All @@ -66,7 +69,7 @@ def test_graph_import_export_cycle(SimpleComp, SimpleBranch, SimpleCell, SimpleN
net.cell(0).insert(K())

# test consistency of exported and re-imported modules
for module in [net, cell, branch, comp]:
for module in [net, morph_cell, cell, branch, comp]:
module.compute_xyz() # ensure x,y,z in nodes b4 exporting for later comparison
module_graph = to_graph(module) # ensure to_graph works
re_module = from_graph(module_graph) # ensure prev exported graph can be read
Expand Down Expand Up @@ -97,6 +100,8 @@ def test_trace_branches(file):
edges = pd.DataFrame([{"u": u, "v": v, **d} for u, v, d in graph.edges(data=True)])
nx_branch_lens = edges.groupby("branch_index")["l"].sum().to_numpy()
nx_branch_lens = np.sort(nx_branch_lens)

# exclude artificial root branch
if np.isclose(nx_branch_lens[0], 1e-1):
nx_branch_lens = nx_branch_lens[1:]

Expand Down

0 comments on commit c0607bb

Please sign in to comment.