Skip to content

Commit

Permalink
Fixups for having heterogenous numbers of comps in each branch
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Oct 10, 2024
1 parent 4bd3fe3 commit e78bb26
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 42 deletions.
9 changes: 3 additions & 6 deletions jaxley/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,7 @@ def sparse_connect(
post_rows = post_cell_view.view.loc[global_post_indices]

# Pre-synapse is at the zero-eth branch and zero-eth compartment.
idcs_to_zero = np.zeros_like(num_pre)
get_global_idx = pre_cell_view.pointer._local_inds_to_global
global_pre_indices = get_global_idx(pre_syn_neurons, idcs_to_zero, idcs_to_zero)
global_pre_indices = pre_cell_view.pointer._cumsum_nseg_per_cell[pre_syn_neurons]
pre_rows = pre_cell_view.view.loc[global_pre_indices]

pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)
Expand Down Expand Up @@ -182,9 +180,8 @@ def connectivity_matrix_connect(
]
post_rows = post_cell_view.view.loc[global_post_indices]

idcs_to_zero = np.zeros_like(from_idx)
get_global_idx = post_cell_view.pointer._local_inds_to_global
global_pre_indices = get_global_idx(pre_cell_inds, idcs_to_zero, idcs_to_zero)
# Pre-synapse is at the zero-eth branch and zero-eth compartment.
global_pre_indices = pre_cell_view.pointer._cumsum_nseg_per_cell[pre_cell_inds]
pre_rows = pre_cell_view.view.loc[global_pre_indices]

pre_cell_view._append_multiple_synapses(pre_rows, post_rows, synapse_type)
43 changes: 30 additions & 13 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,16 @@ def _make_trainable(
if key in view.columns:
view = view[~np.isnan(view[key])]
grouped_view = view.groupby("controlled_by_param")
num_elements_being_set = grouped_view.apply(len).to_numpy()
assert np.all(num_elements_being_set == num_elements_being_set[0]), (
"You are using `make_trainable()` with parameter sharing (e.g. same "
"parameter for an entire cell, or same parameter for entire branches). "
"This error is caused because you are trying to share a parameter "
"across an inhomogenous number of compartments. To overcome this "
"error, write a for-loop across cells (or branches). For example, "
"change `net.cell('all').make_trainable('HH_gNa')` to "
"`for i in range(num_cells): net.cell(i).make_trainable('HH_gNa')`"
)
# Because of this `x.index.values` we cannot support `make_trainable()` on
# the module level for synapse parameters (but only for `SynapseView`).
inds_of_comps = list(grouped_view.apply(lambda x: x.index.values))
Expand Down Expand Up @@ -1439,7 +1449,11 @@ def _scatter(self, ax, col, dims, view, morph_plot_kwargs):
np.isnan(self.xyzr[branch_ind][:, dims])
), "No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`."

comp_fraction = loc_of_index(comp_ind, self.nseg)
comp_fraction = loc_of_index(
comp_ind,
branch_ind,
self.nseg_per_branch,
)
coords = self.xyzr[branch_ind]
interpolated_xyz = interpolate_xyz(comp_fraction, coords)

Expand Down Expand Up @@ -1701,15 +1715,6 @@ def __iter__(self):
for i in range(self.shape[0]):
yield self[i]

def _local_inds_to_global(
self, cell_inds: np.ndarray, branch_inds: np.ndarray, comp_inds: np.ndarray
):
"""Given local inds of cell, branch, and comp, return the global comp index."""
global_ind = (
self.cumsum_nbranches[cell_inds] + branch_inds
) * self.nseg + comp_inds
return global_ind.astype(int)


class View:
"""View of a `Module`."""
Expand Down Expand Up @@ -2030,7 +2035,11 @@ def xyzr(self) -> List[np.ndarray]:
"""
idxs = self.view.global_branch_index.unique()
if self.__class__.__name__ == "CompartmentView":
loc = loc_of_index(self.view.comp_index, self.pointer.nseg)
loc = loc_of_index(
self.view["global_comp_index"].to_numpy(),
self.view["global_branch_index"].to_numpy(),
self.pointer.nseg_per_branch,
)
return list(interpolate_xyz(loc, self.pointer.xyzr[idxs[0]]))
else:
return [self.pointer.xyzr[i] for i in idxs]
Expand All @@ -2056,8 +2065,16 @@ def _append_multiple_synapses(
if is_new: # synapse is not known
self._update_synapse_state_names(synapse_type)

post_loc = loc_of_index(post_rows["comp_index"].to_numpy(), self.pointer.nseg)
pre_loc = loc_of_index(pre_rows["comp_index"].to_numpy(), self.pointer.nseg)
post_loc = loc_of_index(
post_rows["global_comp_index"].to_numpy(),
post_rows["global_branch_index"].to_numpy(),
self.pointer.nseg_per_branch,
)
pre_loc = loc_of_index(
pre_rows["global_comp_index"].to_numpy(),
pre_rows["global_branch_index"].to_numpy(),
self.pointer.nseg_per_branch,
)

# Define new synapses. Each row is one synapse.
new_rows = dict(
Expand Down
2 changes: 2 additions & 0 deletions jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from jaxley.modules.base import GroupView, Module, View
from jaxley.modules.compartment import Compartment, CompartmentView
from jaxley.utils.cell_utils import compute_children_and_parents
from jaxley.utils.misc_utils import cumsum_leading_zero
from jaxley.utils.solver_utils import comp_edges_to_indices


Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(
self.total_nbranches = 1
self.nbranches_per_cell = [1]
self.cumsum_nbranches = jnp.asarray([0, 1])
self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch)

# Indexing.
self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)
Expand Down
38 changes: 32 additions & 6 deletions jaxley/modules/compartment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from jaxley.modules.base import Module, View
from jaxley.utils.cell_utils import (
compute_children_and_parents,
index_of_loc,
interpolate_xyz,
loc_of_index,
local_index_of_loc,
)
from jaxley.utils.misc_utils import cumsum_leading_zero
from jaxley.utils.solver_utils import comp_edges_to_indices


Expand All @@ -41,6 +42,7 @@ def __init__(self):
self.total_nbranches = 1
self.nbranches_per_cell = [1]
self.cumsum_nbranches = jnp.asarray([0, 1])
self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch)

# Setting up the `nodes` for indexing.
self.nodes = pd.DataFrame(
Expand Down Expand Up @@ -120,7 +122,21 @@ def loc(self, loc: float) -> "CompartmentView":
assert (
loc >= 0.0 and loc <= 1.0
), "Compartments must be indexed by a continuous value between 0 and 1."
index = index_of_loc(0, loc, self.pointer.nseg) if loc != "all" else "all"

branch_ind = np.unique(self.view["global_branch_index"].to_numpy())
if loc != "all" and len(branch_ind) != 1:
raise NotImplementedError(
"Using `.loc()` to index a single compartment of multiple branches is "
"not supported. Use a for loop or use `.comp` to index."
)
branch_ind = np.squeeze(branch_ind) # shape == (1,) --> shape == ()

# Cast nseg to numpy because in `local_index_of_loc` we instatiate an array
# of length `nseg`. However, if we use `.data_set()` or `.data_stimulate()`,
# the `local_index_of_loc()` method must be compatible with `jit`. Therefore,
# we have to stop this from being traced here and cast to numpy.
nsegs = np.asarray(self.pointer.nseg_per_branch)
index = local_index_of_loc(loc, branch_ind, nsegs) if loc != "all" else "all"
view = self(index)
view._has_been_called = True
return view
Expand All @@ -135,15 +151,25 @@ def distance(self, endpoint: "CompartmentView") -> float:
endpoint: The compartment to which to compute the distance to.
"""
start_branch = self.view["global_branch_index"].item()
start_comp = self.view["comp_index"].item()
start_comp = self.view["global_comp_index"].item()
start_xyz = interpolate_xyz(
loc_of_index(start_comp, self.pointer.nseg), self.pointer.xyzr[start_branch]
loc_of_index(
start_comp,
start_branch,
self.pointer.nseg_per_branch,
),
self.pointer.xyzr[start_branch],
)

end_branch = endpoint.view["global_branch_index"].item()
end_comp = endpoint.view["comp_index"].item()
end_comp = endpoint.view["global_comp_index"].item()
end_xyz = interpolate_xyz(
loc_of_index(end_comp, self.pointer.nseg), self.pointer.xyzr[end_branch]
loc_of_index(
end_comp,
end_branch,
self.pointer.nseg_per_branch,
),
self.pointer.xyzr[end_branch],
)

return np.sqrt(np.sum((start_xyz - end_xyz) ** 2))
Expand Down
23 changes: 13 additions & 10 deletions jaxley/utils/cell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import pandas as pd
from jax import jit, vmap

from jaxley.utils.misc_utils import cumsum_leading_zero


def equal_segments(branch_property: list, nseg_per_branch: int):
"""Generates segments where some property is the same in each segment.
Expand Down Expand Up @@ -193,8 +195,8 @@ def get_num_neighbours(
return num_neighbours


def index_of_loc(branch_ind: int, loc: float, nseg_per_branch: int) -> int:
"""Returns the index of a segment given a loc in [0, 1] and the index of a branch.
def local_index_of_loc(loc: float, global_branch_ind: int, nseg_per_branch: int) -> int:
"""Returns the local index of a comp given a loc [0, 1] and the index of a branch.
This is used because we specify locations such as synapses as a value between 0 and
1. We have to convert this onto a discrete segment here.
Expand All @@ -205,19 +207,20 @@ def index_of_loc(branch_ind: int, loc: float, nseg_per_branch: int) -> int:
nseg_per_branch: Number of segments of each branch.
Returns:
The index of the compartment within the entire cell.
The local index of the compartment.
"""
nseg = nseg_per_branch # only for convenience.
nseg = nseg_per_branch[global_branch_ind] # only for convenience.
possible_locs = np.linspace(0.5 / nseg, 1 - 0.5 / nseg, nseg)
ind_along_branch = np.argmin(np.abs(possible_locs - loc))
return branch_ind * nseg + ind_along_branch
return ind_along_branch


def loc_of_index(global_comp_index, nseg):
"""Return location corresponding to index."""
index = global_comp_index % nseg
possible_locs = np.linspace(0.5 / nseg, 1 - 0.5 / nseg, nseg)
return possible_locs[index]
def loc_of_index(global_comp_index, global_branch_index, nseg_per_branch):
"""Return location corresponding to global compartment index."""
cumsum_nseg = cumsum_leading_zero(nseg_per_branch)
index = global_comp_index - cumsum_nseg[global_branch_index]
nseg = nseg_per_branch[global_branch_index]
return (0.5 + index) / nseg


def compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2):
Expand Down
6 changes: 4 additions & 2 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import jax

from jaxley.utils.cell_utils import index_of_loc
from jaxley.utils.cell_utils import local_index_of_loc

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
Expand Down Expand Up @@ -55,7 +55,9 @@ def test_connect():
connect(net2[1, 0], net2[2, 0], TestSynapse())

# test after all connections are made, to catch "overwritten" connections
get_comps = lambda locs: [index_of_loc(0, idx, net2.nseg) for idx in locs]
get_comps = lambda locs: [
local_index_of_loc(loc, 0, net2.nseg_per_branch) for loc in locs
]

# check if all connections are made correctly
first_set_edges = net2.edges.iloc[:8]
Expand Down
2 changes: 0 additions & 2 deletions tests/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,10 @@ def test_subclassing_groups_net_set_equivalence():
net1.excitatory.cell([0, 3]).branch(0).comp("all").set("radius", 0.14)
net1.excitatory.cell([0, 5]).branch(1).comp("all").set("length", 0.16)
net1.excitatory.cell("all").branch(1).comp(2).set("axial_resistivity", 1100.0)
net1.excitatory.cell("all").branch(1).loc(0.0).set("axial_resistivity", 1300.0)

net2.cell([0, 3]).branch(0).comp("all").set("radius", 0.14)
net2.cell([0, 5]).branch(1).comp("all").set("length", 0.16)
net2.cell([0, 3, 5]).branch(1).comp(2).set("axial_resistivity", 1100.0)
net2.cell([0, 3, 5]).branch(1).loc(0.0).set("axial_resistivity", 1300.0)

assert all(net1.nodes == net2.nodes)

Expand Down
41 changes: 38 additions & 3 deletions tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import jaxley as jx
from jaxley.channels import HH
from jaxley.utils.cell_utils import index_of_loc, loc_of_index
from jaxley.utils.cell_utils import loc_of_index, local_index_of_loc
from jaxley.utils.misc_utils import childview


Expand Down Expand Up @@ -56,11 +56,18 @@ def test_loc_v_comp():
comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(4)])

cum_nseg = branch.cumsum_nseg
nsegs = branch.nseg_per_branch
branch_ind = 0

assert np.all(branch.comp(0).show() == branch.loc(0.0).show())
assert np.all(branch.comp(3).show() == branch.loc(1.0).show())

assert np.all(branch.loc(loc_of_index(2, 4)).show() == branch.comp(2).show())
assert np.all(branch.comp(index_of_loc(0, 0.4, 4)).show() == branch.loc(0.4).show())
inferred_loc = loc_of_index(2, branch_ind, nsegs)
assert np.all(branch.loc(inferred_loc).show() == branch.comp(2).show())

inferred_ind = local_index_of_loc(0.4, branch_ind, nsegs)
assert np.all(branch.comp(inferred_ind).show() == branch.loc(0.4).show())


def test_shape():
Expand Down Expand Up @@ -228,3 +235,31 @@ def test_comp_indexing_exception_handling():
branch.loc(0.0).comp(0)
with pytest.raises(AttributeError):
branch.loc(0.0).loc(0.0)


def test_indexing_a_compartment_of_many_branches():
comp = jx.Compartment()
branch1 = jx.Branch(comp, nseg=3)
branch2 = jx.Branch(comp, nseg=4)
branch3 = jx.Branch(comp, nseg=5)
cell1 = jx.Cell([branch1, branch2, branch3], parents=[-1, 0, 0])
cell2 = jx.Cell([branch3, branch2], parents=[-1, 0])
net = jx.Network([cell1, cell2])

# Indexing a single compartment of multiple branches is not supported with `loc`.
with pytest.raises(NotImplementedError):
net.cell("all").branch("all").loc(0.0)
with pytest.raises(NotImplementedError):
net.cell(0).branch("all").loc(0.0)
with pytest.raises(NotImplementedError):
net.cell("all").branch(0).loc(0.0)

# Indexing a single compartment of multiple branches is still supported with `comp`.
net.cell("all").branch("all").comp(0)
net.cell(0).branch("all").comp(0)
net.cell("all").branch(0).comp(0)

# Indexing many single compartment of multiple branches is always supported.
net.cell("all").branch("all").loc("all")
net.cell(0).branch("all").loc("all")
net.cell("all").branch(0).loc("all")

0 comments on commit e78bb26

Please sign in to comment.