Skip to content

Commit

Permalink
wip: working on tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Jan 13, 2025
1 parent a4adc29 commit b0d9f58
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 24 deletions.
2 changes: 1 addition & 1 deletion jaxley/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from jaxley.io.graph import from_graph, to_graph
# from jaxley.io.graph import from_graph, to_graph # Leads to circular import
from jaxley.io.swc import read_swc
6 changes: 3 additions & 3 deletions jaxley/io/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ def trace_branches(
graph.add_edge(-1, soma, l=2 * graph.nodes[soma]["r"])
graph = nx.relabel_nodes(graph, {i: i + 1 for i in graph.nodes})

# edges connecting nodes to soma are considered part of the soma -> l = 0.
for i, j in (*graph.in_edges(soma), *graph.out_edges(soma)):
graph.edges[i, j]["l"] = 0
# # edges connecting nodes to soma are considered part of the soma -> l = 0.
# for i, j in (*graph.in_edges(soma), *graph.out_edges(soma)):
# graph.edges[i, j]["l"] = 0

# ensure linear root segment to ensure root branch can be created.
if graph.out_degree(0) > 1:
Expand Down
1 change: 0 additions & 1 deletion jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
)



class Cell(Module):
"""Cell class.
Expand Down
23 changes: 17 additions & 6 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,18 @@ def jaxley2neuron_by_coords(jx_cell, neuron_secs, comp_idx=None, loc=None, nseg=
neuron_coords = np.vstack(
[np.hstack([k * np.ones((v.shape[0], 1)), v]) for k, v in neuron_coords.items()]
)
neuron_coords = pd.DataFrame(neuron_coords, columns=["branch_index", "x", "y", "z"])
neuron_coords["branch_index"] = neuron_coords["branch_index"].astype(int)
neuron_coords = pd.DataFrame(
neuron_coords, columns=["global_branch_index", "x", "y", "z"]
)
neuron_coords["global_branch_index"] = neuron_coords["global_branch_index"].astype(
int
)

neuron_loc_xyz = neuron_coords.groupby("branch_index").mean()
neuron_loc_xyz = neuron_coords.groupby("global_branch_index").mean()
jaxley_loc_xyz = (
jx_cell.branch("all").loc(loc).show().set_index("branch_index")[["x", "y", "z"]]
jx_cell.branch("all")
.loc(loc)
.nodes.set_index("global_branch_index")[["x", "y", "z"]]
)

jaxley2neuron_inds = {}
Expand All @@ -81,11 +87,16 @@ def jaxley2neuron_by_group(
num_basal=10,
):
y_apical = (
jx_cell.apical.show().groupby("branch_index").mean()["y"].abs().sort_values()
jx_cell.apical.nodes.groupby("global_branch_index")
.mean()["y"]
.abs()
.sort_values()
)
trunk_inds = y_apical.index[:num_apical].tolist()
tuft_inds = y_apical.index[-num_tuft:].tolist()
basal_inds = jx_cell.basal.show()["branch_index"].unique()[:num_basal].tolist()
basal_inds = (
jx_cell.basal.nodes["global_branch_index"].unique()[:num_basal].tolist()
)

jaxley2neuron = jaxley2neuron_by_coords(
jx_cell, neuron_secs, comp_idx=comp_idx, loc=loc, nseg=nseg
Expand Down
30 changes: 17 additions & 13 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def test_trace_branches(file):
edges = pd.DataFrame([{"u": u, "v": v, **d} for u, v, d in graph.edges(data=True)])
nx_branch_lens = edges.groupby("branch_index")["l"].sum().to_numpy()
nx_branch_lens = np.sort(nx_branch_lens)
if np.isclose(nx_branch_lens[0], 1e-1):
nx_branch_lens = nx_branch_lens[1:]

h, _ = import_neuron_morph(fname)
neuron_branch_lens = np.sort([sec.L for sec in h.allsec()])
Expand All @@ -108,23 +110,25 @@ def test_trace_branches(file):

@pytest.mark.parametrize("file", ["morph_single_point_soma.swc", "morph.swc"])
def test_from_graph_vs_NEURON(file):
nseg = 8
ncomp = 8
dirname = os.path.dirname(__file__)
fname = os.path.join(dirname, "swc_files", file)

graph = swc_to_graph(fname)
cell = from_graph(
graph, nseg=nseg, max_branch_len=2000, ignore_swc_trace_errors=False
graph, ncomp=ncomp, max_branch_len=2000, ignore_swc_trace_errors=False
)
cell.compute_compartment_centers()
h, neuron_cell = import_neuron_morph(fname, nseg=nseg)
h, neuron_cell = import_neuron_morph(fname, nseg=ncomp)

# remove root branch
jaxley_comps = cell.nodes[
~np.isclose(cell.nodes["length"], 0.1 / nseg)
~np.isclose(cell.nodes["length"], 0.1 / ncomp)
].reset_index(drop=True)

jx_branch_lens = jaxley_comps.groupby("branch_index")["length"].sum().to_numpy()
jx_branch_lens = (
jaxley_comps.groupby("global_branch_index")["length"].sum().to_numpy()
)

# match by branch lengths
neuron_xyzd = [np.array(s.psection()["morphology"]["pts3d"]) for s in h.allsec()]
Expand All @@ -143,15 +147,15 @@ def test_from_graph_vs_NEURON(file):
neuron_comp_k = np.array(
[
get_segment_xyzrL(list(h.allsec())[neuron_inds[k]], comp_idx=i)
for i in range(nseg)
for i in range(ncomp)
]
)
# make this a dataframe
neuron_comp_k = pd.DataFrame(
neuron_comp_k, columns=["x", "y", "z", "radius", "length"]
)
neuron_comp_k["idx"] = neuron_inds[k]
jx_comp_k = jaxley_comps[jaxley_comps["branch_index"] == jx_inds[k]][
jx_comp_k = jaxley_comps[jaxley_comps["global_branch_index"] == jx_inds[k]][
["x", "y", "z", "radius", "length"]
]
jx_comp_k["idx"] = jx_inds[k]
Expand Down Expand Up @@ -190,7 +194,7 @@ def test_graph_to_jaxley(file):
graph = swc_to_graph(fname)
swc_module = from_graph(graph)
for group in ["soma", "apical", "basal"]:
assert group in swc_module.group_nodes
assert group in swc_module.groups

# test import after different stages of graph pre-processing
graph = swc_to_graph(fname)
Expand All @@ -216,7 +220,7 @@ def test_swc2graph_voltages(file):
dirname = os.path.dirname(__file__)
fname = os.path.join(dirname, "swc_files", file) # n120

nseg = 8
ncomp = 8

i_delay = 2.0
i_dur = 5.0
Expand All @@ -225,14 +229,14 @@ def test_swc2graph_voltages(file):
dt = 0.025

##################### NEURON ##################
h, neuron_cell = import_neuron_morph(fname, nseg=nseg)
h, neuron_cell = import_neuron_morph(fname, nseg=ncomp)

####################### jaxley ##################
graph = swc_to_graph(fname)
jx_cell = from_graph(
graph, nseg=nseg, max_branch_len=2000, ignore_swc_trace_errors=False
graph, ncomp=ncomp, max_branch_len=2000, ignore_swc_trace_errors=False
)
jx_cell._update_nodes_with_xyz()
jx_cell.compute_compartment_centers()
jx_cell.insert(HH())

branch_loc = 0.05
Expand All @@ -255,7 +259,7 @@ def test_swc2graph_voltages(file):
jx_cell.set("HH_h", 0.4889)
jx_cell.set("HH_n", 0.3644787)

jx_cell.branch.comp(stim_idx).stimulate(
jx_cell.select(stim_idx).stimulate(
jx.step_current(i_delay, i_dur, i_amp, dt, t_max)
)
for i in trunk_inds + tuft_inds + basal_inds:
Expand Down

0 comments on commit b0d9f58

Please sign in to comment.