diff --git a/jaxley/io/__init__.py b/jaxley/io/__init__.py index 2f211300..23a386cc 100644 --- a/jaxley/io/__init__.py +++ b/jaxley/io/__init__.py @@ -1,5 +1,5 @@ # This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is # licensed under the Apache License Version 2.0, see -from jaxley.io.graph import from_graph, to_graph +# from jaxley.io.graph import from_graph, to_graph # Leads to circular import from jaxley.io.swc import read_swc diff --git a/jaxley/io/graph.py b/jaxley/io/graph.py index 04baaa8c..3525e382 100644 --- a/jaxley/io/graph.py +++ b/jaxley/io/graph.py @@ -176,9 +176,9 @@ def trace_branches( graph.add_edge(-1, soma, l=2 * graph.nodes[soma]["r"]) graph = nx.relabel_nodes(graph, {i: i + 1 for i in graph.nodes}) - # edges connecting nodes to soma are considered part of the soma -> l = 0. - for i, j in (*graph.in_edges(soma), *graph.out_edges(soma)): - graph.edges[i, j]["l"] = 0 + # # edges connecting nodes to soma are considered part of the soma -> l = 0. + # for i, j in (*graph.in_edges(soma), *graph.out_edges(soma)): + # graph.edges[i, j]["l"] = 0 # ensure linear root segment to ensure root branch can be created. if graph.out_degree(0) > 1: diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 97d5549e..3d6b39da 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -27,7 +27,6 @@ ) - class Cell(Module): """Cell class. diff --git a/tests/helpers.py b/tests/helpers.py index bd9c5316..172b5e12 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -55,12 +55,18 @@ def jaxley2neuron_by_coords(jx_cell, neuron_secs, comp_idx=None, loc=None, nseg= neuron_coords = np.vstack( [np.hstack([k * np.ones((v.shape[0], 1)), v]) for k, v in neuron_coords.items()] ) - neuron_coords = pd.DataFrame(neuron_coords, columns=["branch_index", "x", "y", "z"]) - neuron_coords["branch_index"] = neuron_coords["branch_index"].astype(int) + neuron_coords = pd.DataFrame( + neuron_coords, columns=["global_branch_index", "x", "y", "z"] + ) + neuron_coords["global_branch_index"] = neuron_coords["global_branch_index"].astype( + int + ) - neuron_loc_xyz = neuron_coords.groupby("branch_index").mean() + neuron_loc_xyz = neuron_coords.groupby("global_branch_index").mean() jaxley_loc_xyz = ( - jx_cell.branch("all").loc(loc).show().set_index("branch_index")[["x", "y", "z"]] + jx_cell.branch("all") + .loc(loc) + .nodes.set_index("global_branch_index")[["x", "y", "z"]] ) jaxley2neuron_inds = {} @@ -81,11 +87,16 @@ def jaxley2neuron_by_group( num_basal=10, ): y_apical = ( - jx_cell.apical.show().groupby("branch_index").mean()["y"].abs().sort_values() + jx_cell.apical.nodes.groupby("global_branch_index") + .mean()["y"] + .abs() + .sort_values() ) trunk_inds = y_apical.index[:num_apical].tolist() tuft_inds = y_apical.index[-num_tuft:].tolist() - basal_inds = jx_cell.basal.show()["branch_index"].unique()[:num_basal].tolist() + basal_inds = ( + jx_cell.basal.nodes["global_branch_index"].unique()[:num_basal].tolist() + ) jaxley2neuron = jaxley2neuron_by_coords( jx_cell, neuron_secs, comp_idx=comp_idx, loc=loc, nseg=nseg diff --git a/tests/test_graph.py b/tests/test_graph.py index 72b2069b..3e6b2802 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -97,6 +97,8 @@ def test_trace_branches(file): edges = pd.DataFrame([{"u": u, "v": v, **d} for u, v, d in graph.edges(data=True)]) nx_branch_lens = edges.groupby("branch_index")["l"].sum().to_numpy() nx_branch_lens = np.sort(nx_branch_lens) + if np.isclose(nx_branch_lens[0], 1e-1): + nx_branch_lens = nx_branch_lens[1:] h, _ = import_neuron_morph(fname) neuron_branch_lens = np.sort([sec.L for sec in h.allsec()]) @@ -108,23 +110,25 @@ def test_trace_branches(file): @pytest.mark.parametrize("file", ["morph_single_point_soma.swc", "morph.swc"]) def test_from_graph_vs_NEURON(file): - nseg = 8 + ncomp = 8 dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", file) graph = swc_to_graph(fname) cell = from_graph( - graph, nseg=nseg, max_branch_len=2000, ignore_swc_trace_errors=False + graph, ncomp=ncomp, max_branch_len=2000, ignore_swc_trace_errors=False ) cell.compute_compartment_centers() - h, neuron_cell = import_neuron_morph(fname, nseg=nseg) + h, neuron_cell = import_neuron_morph(fname, nseg=ncomp) # remove root branch jaxley_comps = cell.nodes[ - ~np.isclose(cell.nodes["length"], 0.1 / nseg) + ~np.isclose(cell.nodes["length"], 0.1 / ncomp) ].reset_index(drop=True) - jx_branch_lens = jaxley_comps.groupby("branch_index")["length"].sum().to_numpy() + jx_branch_lens = ( + jaxley_comps.groupby("global_branch_index")["length"].sum().to_numpy() + ) # match by branch lengths neuron_xyzd = [np.array(s.psection()["morphology"]["pts3d"]) for s in h.allsec()] @@ -143,7 +147,7 @@ def test_from_graph_vs_NEURON(file): neuron_comp_k = np.array( [ get_segment_xyzrL(list(h.allsec())[neuron_inds[k]], comp_idx=i) - for i in range(nseg) + for i in range(ncomp) ] ) # make this a dataframe @@ -151,7 +155,7 @@ def test_from_graph_vs_NEURON(file): neuron_comp_k, columns=["x", "y", "z", "radius", "length"] ) neuron_comp_k["idx"] = neuron_inds[k] - jx_comp_k = jaxley_comps[jaxley_comps["branch_index"] == jx_inds[k]][ + jx_comp_k = jaxley_comps[jaxley_comps["global_branch_index"] == jx_inds[k]][ ["x", "y", "z", "radius", "length"] ] jx_comp_k["idx"] = jx_inds[k] @@ -190,7 +194,7 @@ def test_graph_to_jaxley(file): graph = swc_to_graph(fname) swc_module = from_graph(graph) for group in ["soma", "apical", "basal"]: - assert group in swc_module.group_nodes + assert group in swc_module.groups # test import after different stages of graph pre-processing graph = swc_to_graph(fname) @@ -216,7 +220,7 @@ def test_swc2graph_voltages(file): dirname = os.path.dirname(__file__) fname = os.path.join(dirname, "swc_files", file) # n120 - nseg = 8 + ncomp = 8 i_delay = 2.0 i_dur = 5.0 @@ -225,14 +229,14 @@ def test_swc2graph_voltages(file): dt = 0.025 ##################### NEURON ################## - h, neuron_cell = import_neuron_morph(fname, nseg=nseg) + h, neuron_cell = import_neuron_morph(fname, nseg=ncomp) ####################### jaxley ################## graph = swc_to_graph(fname) jx_cell = from_graph( - graph, nseg=nseg, max_branch_len=2000, ignore_swc_trace_errors=False + graph, ncomp=ncomp, max_branch_len=2000, ignore_swc_trace_errors=False ) - jx_cell._update_nodes_with_xyz() + jx_cell.compute_compartment_centers() jx_cell.insert(HH()) branch_loc = 0.05 @@ -255,7 +259,7 @@ def test_swc2graph_voltages(file): jx_cell.set("HH_h", 0.4889) jx_cell.set("HH_n", 0.3644787) - jx_cell.branch.comp(stim_idx).stimulate( + jx_cell.select(stim_idx).stimulate( jx.step_current(i_delay, i_dur, i_amp, dt, t_max) ) for i in trunk_inds + tuft_inds + basal_inds: