Skip to content

Commit

Permalink
Allow changing nseg after initialization (#436)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler authored Oct 4, 2024
1 parent 4fc75a1 commit dfd5087
Show file tree
Hide file tree
Showing 8 changed files with 542 additions and 53 deletions.
145 changes: 144 additions & 1 deletion jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Tuple, Union
from warnings import warn

import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -32,9 +33,14 @@
v_interp,
)
from jaxley.utils.debug_solver import compute_morphology_indices
from jaxley.utils.misc_utils import childview, concat_and_ignore_empty
from jaxley.utils.misc_utils import (
childview,
concat_and_ignore_empty,
cumsum_leading_zero,
)
from jaxley.utils.plot_utils import plot_comps, plot_morph
from jaxley.utils.solver_utils import convert_to_csc
from jaxley.utils.swc import build_radiuses_from_xyzr


class Module(ABC):
Expand Down Expand Up @@ -109,6 +115,7 @@ def __init__(self):

# x, y, z coordinates and radius.
self.xyzr: List[np.ndarray] = []
self._radius_generating_fns = None # Defined by `.read_swc()`.

# For debugging the solver. Will be empty by default and only filled if
# `self._init_morph_for_debugging` is run.
Expand Down Expand Up @@ -157,6 +164,14 @@ def __dir__(self):
base_dir = object.__dir__(self)
return sorted(base_dir + self.synapse_names + list(self.group_nodes.keys()))

@property
def _module_type(self):
"""Return type of the module (compartment, branch, cell, network) as string.
This is used to perform asserts for some modules (e.g. network cannot use
`set_ncomp`) without having to import the module in `base.py`."""
return self.__class__.__name__.lower()

def _append_params_and_states(self, param_dict: Dict, state_dict: Dict):
"""Insert the default params of the module (e.g. radius, length).
Expand Down Expand Up @@ -395,6 +410,134 @@ def _data_set(
raise KeyError("Key not recognized.")
return param_state

def _set_ncomp(
self,
ncomp: int,
view: pd.DataFrame,
all_nodes: pd.DataFrame,
start_idx: int,
nseg_per_branch: jnp.asarray,
channel_names: List[str],
channel_param_names: List[str],
channel_state_names: List[str],
radius_generating_fns: List[Callable],
min_radius: Optional[float],
):
"""Set the number of compartments with which the branch is discretized."""
within_branch_radiuses = view["radius"].to_numpy()
compartment_lengths = view["length"].to_numpy()
num_previous_ncomp = len(within_branch_radiuses)
branch_indices = pd.unique(view["branch_index"])

error_msg = lambda name: (
f"You previously modified the {name} of individual compartments, but "
f"now you are modifying the number of compartments in this branch. "
f"This is not allowed. First build the morphology with `set_ncomp()` and "
f"then modify the radiuses and lengths of compartments."
)

if (
~np.all(within_branch_radiuses == within_branch_radiuses[0])
and radius_generating_fns is None
):
raise ValueError(error_msg("radius"))

for property_name in ["length", "capacitance", "axial_resistivity"]:
compartment_properties = view[property_name].to_numpy()
if ~np.all(compartment_properties == compartment_properties[0]):
raise ValueError(error_msg(property_name))

if not (view[channel_names].var() == 0.0).all():
raise ValueError(
"Some channel exists only in some compartments of the branch which you"
"are trying to modify. This is not allowed. First specify the number"
"of compartments with `.set_ncomp()` and then insert the channels"
"accordingly."
)

if not (view[channel_param_names + channel_state_names].var() == 0.0).all():
raise ValueError(
"Some channel has different parameters or states between the "
"different compartments of the branch which you are trying to modify. "
"This is not allowed. First specify the number of compartments with "
"`.set_ncomp()` and then insert the channels accordingly."
)

# Add new rows as the average of all rows. Special case for the length is below.
average_row = view.mean(skipna=False)
average_row = average_row.to_frame().T
view = pd.concat([*[average_row] * ncomp], axis="rows")

# If the `view` is not the entire `Module`, but a `View` (i.e. if one changes
# the number of comps within a branch of a cell), then the `self.pointer.view`
# will contain the additional `global_xyz_index` columns. However, the
# `self.nodes` will not have these columns.
#
# Note that we assert that there are no trainables, so `controlled_by_params`
# of the `self.nodes` has to be empty.
if "global_comp_index" in view.columns:
view = view.drop(
columns=[
"global_comp_index",
"global_branch_index",
"global_cell_index",
"controlled_by_param",
]
)

# Set the correct datatype after having performed an average which cast
# everything to float.
integer_cols = ["comp_index", "branch_index", "cell_index"]
view[integer_cols] = view[integer_cols].astype(int)

# Whether or not a channel exists in a compartment is a boolean.
boolean_cols = channel_names
view[boolean_cols] = view[boolean_cols].astype(bool)

# Special treatment for the lengths and radiuses. These are not being set as
# the average because we:
# 1) Want to maintain the total length of a branch.
# 2) Want to use the SWC inferred radius.
#
# Compute new compartment lengths.
comp_lengths = np.sum(compartment_lengths) / ncomp
view["length"] = comp_lengths

# Compute new compartment radiuses.
if radius_generating_fns is not None:
view["radius"] = build_radiuses_from_xyzr(
radius_fns=radius_generating_fns,
branch_indices=branch_indices,
min_radius=min_radius,
nseg=ncomp,
)
else:
view["radius"] = within_branch_radiuses[0] * np.ones(ncomp)

# Update `.nodes`.
#
# 1) Delete N rows starting from start_idx
number_deleted = num_previous_ncomp
all_nodes = all_nodes.drop(index=range(start_idx, start_idx + number_deleted))

# 2) Insert M new rows at the same location
df1 = all_nodes.iloc[:start_idx] # Rows before the insertion point
df2 = all_nodes.iloc[start_idx:] # Rows after the insertion point

# 3) Combine the parts: before, new rows, and after
all_nodes = pd.concat([df1, view, df2]).reset_index(drop=True)

# Override `comp_index` to just be a consecutive list.
all_nodes["comp_index"] = np.arange(len(all_nodes))

# Update compartment structure arguments.
nseg_per_branch = nseg_per_branch.at[branch_indices].set(ncomp)
nseg = int(jnp.max(nseg_per_branch))
cumsum_nseg = cumsum_leading_zero(nseg_per_branch)
internal_node_inds = np.arange(cumsum_nseg[-1])

return all_nodes, nseg_per_branch, nseg, cumsum_nseg, internal_node_inds

def make_trainable(
self,
key: str,
Expand Down
105 changes: 104 additions & 1 deletion jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from copy import deepcopy
from itertools import chain
from typing import Callable, Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(
compartment_list = compartments

self.nseg = len(compartment_list)
self.nseg_per_branch = [self.nseg]
self.nseg_per_branch = jnp.asarray([self.nseg])
self.total_nbranches = 1
self.nbranches_per_cell = [1]
self.cumsum_nbranches = jnp.asarray([0, 1])
Expand Down Expand Up @@ -146,6 +147,51 @@ def _init_morph_jax_spsolve(self):
def __len__(self) -> int:
return self.nseg

def set_ncomp(self, ncomp: int, min_radius: Optional[float] = None):
"""Set the number of compartments with which the branch is discretized.
Args:
ncomp: The number of compartments that the branch should be discretized
into.
Raises:
- When the Module is a Network.
- When there are stimuli in any compartment in the Module.
- When there are recordings in any compartment in the Module.
- When the channels of the compartments are not the same within the branch
that is modified.
- When the lengths of the compartments are not the same within the branch
that is modified.
- Unless the morphology was read from an SWC file, when the radiuses of the
compartments are not the same within the branch that is modified.
"""
assert len(self.externals) == 0, "No stimuli allowed!"
assert len(self.recordings) == 0, "No recordings allowed!"
assert len(self.trainable_params) == 0, "No trainables allowed!"

# Update all attributes that are affected by compartment structure.
(
self.nodes,
self.nseg_per_branch,
self.nseg,
self.cumsum_nseg,
self._internal_node_inds,
) = self._set_ncomp(
ncomp,
self.nodes,
self.nodes,
self.nodes["comp_index"].to_numpy()[0],
self.nseg_per_branch,
[c._name for c in self.channels],
list(chain(*[c.channel_params for c in self.channels])),
list(chain(*[c.channel_states for c in self.channels])),
self._radius_generating_fns,
min_radius,
)

# Update the morphology indexing (e.g., `.comp_edges`).
self.initialize()


class BranchView(View):
"""BranchView."""
Expand All @@ -167,3 +213,60 @@ def __getattr__(self, key):
assert key in ["comp", "loc"]
compview = CompartmentView(self.pointer, self.view)
return compview if key == "comp" else compview.loc

def set_ncomp(self, ncomp: int, min_radius: Optional[float] = None):
"""Set the number of compartments with which the branch is discretized.
Args:
ncomp: The number of compartments that the branch should be discretized
into.
min_radius: Only used if the morphology was read from an SWC file. If passed
the radius is capped to be at least this value.
Raises:
- When there are stimuli in any compartment in the module.
- When there are recordings in any compartment in the module.
- When the channels of the compartments are not the same within the branch
that is modified.
- When the lengths of the compartments are not the same within the branch
that is modified.
- Unless the morphology was read from an SWC file, when the radiuses of the
compartments are not the same within the branch that is modified.
"""
if self.pointer._module_type == "network":
raise NotImplementedError(
"`.set_ncomp` is not yet supported for a `Network`. To overcome this, "
"first build individual cells with the desired `ncomp` and then "
"assemble them into a network."
)

error_msg = lambda name: (
f"Your module contains a {name}. This is not allowed. First build the "
"morphology with `set_ncomp()` and then insert stimuli, recordings, and "
"define trainables."
)
assert len(self.pointer.externals) == 0, error_msg("stimulus")
assert len(self.pointer.recordings) == 0, error_msg("recording")
assert len(self.pointer.trainable_params) == 0, error_msg("trainable parameter")
# Update all attributes that are affected by compartment structure.
(
self.pointer.nodes,
self.pointer.nseg_per_branch,
self.pointer.nseg,
self.pointer.cumsum_nseg,
self.pointer._internal_node_inds,
) = self.pointer._set_ncomp(
ncomp,
self.view,
self.pointer.nodes,
self.view["global_comp_index"].to_numpy()[0],
self.pointer.nseg_per_branch,
[c._name for c in self.pointer.channels],
list(chain(*[c.channel_params for c in self.pointer.channels])),
list(chain(*[c.channel_states for c in self.pointer.channels])),
self.pointer._radius_generating_fns,
min_radius,
)

# Update the morphology indexing (e.g., `.comp_edges`).
self.pointer.initialize()
Loading

0 comments on commit dfd5087

Please sign in to comment.