From dfd5087728bf4c5ac3086605f49ab3300cdb3639 Mon Sep 17 00:00:00 2001 From: Michael Deistler Date: Fri, 4 Oct 2024 15:26:48 +0200 Subject: [PATCH] Allow changing nseg after initialization (#436) --- jaxley/modules/base.py | 145 ++++++++++++++++++++++++- jaxley/modules/branch.py | 105 +++++++++++++++++- jaxley/modules/cell.py | 86 ++++++++------- jaxley/modules/compartment.py | 2 +- jaxley/modules/network.py | 21 ++-- jaxley/utils/misc_utils.py | 8 ++ jaxley/utils/swc.py | 34 ++++++ tests/test_set_ncomp.py | 194 ++++++++++++++++++++++++++++++++++ 8 files changed, 542 insertions(+), 53 deletions(-) create mode 100644 tests/test_set_ncomp.py diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index c3e6a34c..cf2dfa3f 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from copy import deepcopy from typing import Callable, Dict, List, Optional, Tuple, Union +from warnings import warn import jax.numpy as jnp import numpy as np @@ -32,9 +33,14 @@ v_interp, ) from jaxley.utils.debug_solver import compute_morphology_indices -from jaxley.utils.misc_utils import childview, concat_and_ignore_empty +from jaxley.utils.misc_utils import ( + childview, + concat_and_ignore_empty, + cumsum_leading_zero, +) from jaxley.utils.plot_utils import plot_comps, plot_morph from jaxley.utils.solver_utils import convert_to_csc +from jaxley.utils.swc import build_radiuses_from_xyzr class Module(ABC): @@ -109,6 +115,7 @@ def __init__(self): # x, y, z coordinates and radius. self.xyzr: List[np.ndarray] = [] + self._radius_generating_fns = None # Defined by `.read_swc()`. # For debugging the solver. Will be empty by default and only filled if # `self._init_morph_for_debugging` is run. @@ -157,6 +164,14 @@ def __dir__(self): base_dir = object.__dir__(self) return sorted(base_dir + self.synapse_names + list(self.group_nodes.keys())) + @property + def _module_type(self): + """Return type of the module (compartment, branch, cell, network) as string. + + This is used to perform asserts for some modules (e.g. network cannot use + `set_ncomp`) without having to import the module in `base.py`.""" + return self.__class__.__name__.lower() + def _append_params_and_states(self, param_dict: Dict, state_dict: Dict): """Insert the default params of the module (e.g. radius, length). @@ -395,6 +410,134 @@ def _data_set( raise KeyError("Key not recognized.") return param_state + def _set_ncomp( + self, + ncomp: int, + view: pd.DataFrame, + all_nodes: pd.DataFrame, + start_idx: int, + nseg_per_branch: jnp.asarray, + channel_names: List[str], + channel_param_names: List[str], + channel_state_names: List[str], + radius_generating_fns: List[Callable], + min_radius: Optional[float], + ): + """Set the number of compartments with which the branch is discretized.""" + within_branch_radiuses = view["radius"].to_numpy() + compartment_lengths = view["length"].to_numpy() + num_previous_ncomp = len(within_branch_radiuses) + branch_indices = pd.unique(view["branch_index"]) + + error_msg = lambda name: ( + f"You previously modified the {name} of individual compartments, but " + f"now you are modifying the number of compartments in this branch. " + f"This is not allowed. First build the morphology with `set_ncomp()` and " + f"then modify the radiuses and lengths of compartments." + ) + + if ( + ~np.all(within_branch_radiuses == within_branch_radiuses[0]) + and radius_generating_fns is None + ): + raise ValueError(error_msg("radius")) + + for property_name in ["length", "capacitance", "axial_resistivity"]: + compartment_properties = view[property_name].to_numpy() + if ~np.all(compartment_properties == compartment_properties[0]): + raise ValueError(error_msg(property_name)) + + if not (view[channel_names].var() == 0.0).all(): + raise ValueError( + "Some channel exists only in some compartments of the branch which you" + "are trying to modify. This is not allowed. First specify the number" + "of compartments with `.set_ncomp()` and then insert the channels" + "accordingly." + ) + + if not (view[channel_param_names + channel_state_names].var() == 0.0).all(): + raise ValueError( + "Some channel has different parameters or states between the " + "different compartments of the branch which you are trying to modify. " + "This is not allowed. First specify the number of compartments with " + "`.set_ncomp()` and then insert the channels accordingly." + ) + + # Add new rows as the average of all rows. Special case for the length is below. + average_row = view.mean(skipna=False) + average_row = average_row.to_frame().T + view = pd.concat([*[average_row] * ncomp], axis="rows") + + # If the `view` is not the entire `Module`, but a `View` (i.e. if one changes + # the number of comps within a branch of a cell), then the `self.pointer.view` + # will contain the additional `global_xyz_index` columns. However, the + # `self.nodes` will not have these columns. + # + # Note that we assert that there are no trainables, so `controlled_by_params` + # of the `self.nodes` has to be empty. + if "global_comp_index" in view.columns: + view = view.drop( + columns=[ + "global_comp_index", + "global_branch_index", + "global_cell_index", + "controlled_by_param", + ] + ) + + # Set the correct datatype after having performed an average which cast + # everything to float. + integer_cols = ["comp_index", "branch_index", "cell_index"] + view[integer_cols] = view[integer_cols].astype(int) + + # Whether or not a channel exists in a compartment is a boolean. + boolean_cols = channel_names + view[boolean_cols] = view[boolean_cols].astype(bool) + + # Special treatment for the lengths and radiuses. These are not being set as + # the average because we: + # 1) Want to maintain the total length of a branch. + # 2) Want to use the SWC inferred radius. + # + # Compute new compartment lengths. + comp_lengths = np.sum(compartment_lengths) / ncomp + view["length"] = comp_lengths + + # Compute new compartment radiuses. + if radius_generating_fns is not None: + view["radius"] = build_radiuses_from_xyzr( + radius_fns=radius_generating_fns, + branch_indices=branch_indices, + min_radius=min_radius, + nseg=ncomp, + ) + else: + view["radius"] = within_branch_radiuses[0] * np.ones(ncomp) + + # Update `.nodes`. + # + # 1) Delete N rows starting from start_idx + number_deleted = num_previous_ncomp + all_nodes = all_nodes.drop(index=range(start_idx, start_idx + number_deleted)) + + # 2) Insert M new rows at the same location + df1 = all_nodes.iloc[:start_idx] # Rows before the insertion point + df2 = all_nodes.iloc[start_idx:] # Rows after the insertion point + + # 3) Combine the parts: before, new rows, and after + all_nodes = pd.concat([df1, view, df2]).reset_index(drop=True) + + # Override `comp_index` to just be a consecutive list. + all_nodes["comp_index"] = np.arange(len(all_nodes)) + + # Update compartment structure arguments. + nseg_per_branch = nseg_per_branch.at[branch_indices].set(ncomp) + nseg = int(jnp.max(nseg_per_branch)) + cumsum_nseg = cumsum_leading_zero(nseg_per_branch) + internal_node_inds = np.arange(cumsum_nseg[-1]) + + return all_nodes, nseg_per_branch, nseg, cumsum_nseg, internal_node_inds + def make_trainable( self, key: str, diff --git a/jaxley/modules/branch.py b/jaxley/modules/branch.py index 3933c929..c8c3ab56 100644 --- a/jaxley/modules/branch.py +++ b/jaxley/modules/branch.py @@ -2,6 +2,7 @@ # licensed under the Apache License Version 2.0, see from copy import deepcopy +from itertools import chain from typing import Callable, Dict, List, Optional, Tuple, Union import jax.numpy as jnp @@ -56,7 +57,7 @@ def __init__( compartment_list = compartments self.nseg = len(compartment_list) - self.nseg_per_branch = [self.nseg] + self.nseg_per_branch = jnp.asarray([self.nseg]) self.total_nbranches = 1 self.nbranches_per_cell = [1] self.cumsum_nbranches = jnp.asarray([0, 1]) @@ -146,6 +147,51 @@ def _init_morph_jax_spsolve(self): def __len__(self) -> int: return self.nseg + def set_ncomp(self, ncomp: int, min_radius: Optional[float] = None): + """Set the number of compartments with which the branch is discretized. + + Args: + ncomp: The number of compartments that the branch should be discretized + into. + + Raises: + - When the Module is a Network. + - When there are stimuli in any compartment in the Module. + - When there are recordings in any compartment in the Module. + - When the channels of the compartments are not the same within the branch + that is modified. + - When the lengths of the compartments are not the same within the branch + that is modified. + - Unless the morphology was read from an SWC file, when the radiuses of the + compartments are not the same within the branch that is modified. + """ + assert len(self.externals) == 0, "No stimuli allowed!" + assert len(self.recordings) == 0, "No recordings allowed!" + assert len(self.trainable_params) == 0, "No trainables allowed!" + + # Update all attributes that are affected by compartment structure. + ( + self.nodes, + self.nseg_per_branch, + self.nseg, + self.cumsum_nseg, + self._internal_node_inds, + ) = self._set_ncomp( + ncomp, + self.nodes, + self.nodes, + self.nodes["comp_index"].to_numpy()[0], + self.nseg_per_branch, + [c._name for c in self.channels], + list(chain(*[c.channel_params for c in self.channels])), + list(chain(*[c.channel_states for c in self.channels])), + self._radius_generating_fns, + min_radius, + ) + + # Update the morphology indexing (e.g., `.comp_edges`). + self.initialize() + class BranchView(View): """BranchView.""" @@ -167,3 +213,60 @@ def __getattr__(self, key): assert key in ["comp", "loc"] compview = CompartmentView(self.pointer, self.view) return compview if key == "comp" else compview.loc + + def set_ncomp(self, ncomp: int, min_radius: Optional[float] = None): + """Set the number of compartments with which the branch is discretized. + + Args: + ncomp: The number of compartments that the branch should be discretized + into. + min_radius: Only used if the morphology was read from an SWC file. If passed + the radius is capped to be at least this value. + + Raises: + - When there are stimuli in any compartment in the module. + - When there are recordings in any compartment in the module. + - When the channels of the compartments are not the same within the branch + that is modified. + - When the lengths of the compartments are not the same within the branch + that is modified. + - Unless the morphology was read from an SWC file, when the radiuses of the + compartments are not the same within the branch that is modified. + """ + if self.pointer._module_type == "network": + raise NotImplementedError( + "`.set_ncomp` is not yet supported for a `Network`. To overcome this, " + "first build individual cells with the desired `ncomp` and then " + "assemble them into a network." + ) + + error_msg = lambda name: ( + f"Your module contains a {name}. This is not allowed. First build the " + "morphology with `set_ncomp()` and then insert stimuli, recordings, and " + "define trainables." + ) + assert len(self.pointer.externals) == 0, error_msg("stimulus") + assert len(self.pointer.recordings) == 0, error_msg("recording") + assert len(self.pointer.trainable_params) == 0, error_msg("trainable parameter") + # Update all attributes that are affected by compartment structure. + ( + self.pointer.nodes, + self.pointer.nseg_per_branch, + self.pointer.nseg, + self.pointer.cumsum_nseg, + self.pointer._internal_node_inds, + ) = self.pointer._set_ncomp( + ncomp, + self.view, + self.pointer.nodes, + self.view["global_comp_index"].to_numpy()[0], + self.pointer.nseg_per_branch, + [c._name for c in self.pointer.channels], + list(chain(*[c.channel_params for c in self.pointer.channels])), + list(chain(*[c.channel_states for c in self.pointer.channels])), + self.pointer._radius_generating_fns, + min_radius, + ) + + # Update the morphology indexing (e.g., `.comp_edges`). + self.pointer.initialize() diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 169fca46..eb21fc4b 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -20,8 +20,9 @@ compute_morphology_indices_in_levels, compute_parents_in_level, ) +from jaxley.utils.misc_utils import cumsum_leading_zero from jaxley.utils.solver_utils import comp_edges_to_indices, remap_index_to_masked -from jaxley.utils.swc import swc_to_jaxley +from jaxley.utils.swc import build_radiuses_from_xyzr, swc_to_jaxley class Cell(Module): @@ -70,9 +71,9 @@ def __init__( parents = [-1] if parents is None else parents if isinstance(branches, Branch): - self.branch_list = [branches for _ in range(len(parents))] + branch_list = [branches for _ in range(len(parents))] else: - self.branch_list = branches + branch_list = branches if xyzr is not None: assert len(xyzr) == len(parents) @@ -85,28 +86,33 @@ def __init__( # self.xyzr at `.vis()`. self.xyzr = [float("NaN") * np.zeros((2, 4)) for _ in range(len(parents))] - self.nseg_per_branch = jnp.asarray([branch.nseg for branch in self.branch_list]) - self.nseg = int(jnp.max(self.nseg_per_branch)) - self.cumsum_nseg = jnp.concatenate( - [jnp.asarray([0]), jnp.cumsum(self.nseg_per_branch)] - ) - self.total_nbranches = len(self.branch_list) - self.nbranches_per_cell = [len(self.branch_list)] + self.total_nbranches = len(branch_list) + self.nbranches_per_cell = [len(branch_list)] self.comb_parents = jnp.asarray(parents) self.comb_children = compute_children_indices(self.comb_parents) - self.cumsum_nbranches = jnp.asarray([0, len(self.branch_list)]) + self.cumsum_nbranches = jnp.asarray([0, len(branch_list)]) - # Indexing. - self.nodes = pd.concat([c.nodes for c in self.branch_list], ignore_index=True) - self._append_params_and_states(self.cell_params, self.cell_states) + # Compartment structure. These arguments have to be rebuilt when `.set_ncomp()` + # is run. + self.nseg_per_branch = jnp.asarray([branch.nseg for branch in branch_list]) + self.nseg = int(jnp.max(self.nseg_per_branch)) + self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch) + self._internal_node_inds = np.arange(self.cumsum_nseg[-1]) + + # Build nodes. Has to be changed when `.set_ncomp()` is run. + self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True) self.nodes["comp_index"] = np.arange(self.cumsum_nseg[-1]) self.nodes["branch_index"] = np.repeat( np.arange(self.total_nbranches), self.nseg_per_branch ).tolist() self.nodes["cell_index"] = np.repeat(0, self.cumsum_nseg[-1]).tolist() + # Appending general parameters (radius, length, r_a, cm) and channel parameters, + # as well as the states (v, and channel states). + self._append_params_and_states(self.cell_params, self.cell_states) + # Channels. - self._gather_channels_from_constituents(self.branch_list) + self._gather_channels_from_constituents(branch_list) # Synapse indexing. self.syn_edges = pd.DataFrame( @@ -123,7 +129,6 @@ def __init__( self.par_inds, self.child_inds, self.child_belongs_to_branchpoint = ( compute_children_and_parents(self.branch_edges) ) - self._internal_node_inds = np.arange(self.cumsum_nseg[-1]) self.initialize() self.init_syns() @@ -203,17 +208,22 @@ def _init_morph_jax_spsolve(self): """ # Edges between compartments within the branches. - # `[offset, offset, 0]` because we want to offset `source` and `sink`, but - # not `type`. self._comp_edges = pd.concat( [ - [offset, offset, 0] + branch._comp_edges - for offset, branch in zip(self.cumsum_nseg, self.branch_list) + pd.DataFrame() + .from_dict( + { + "source": list(range(cumsum_nseg, nseg - 1 + cumsum_nseg)) + + list(range(1 + cumsum_nseg, nseg + cumsum_nseg)), + "sink": list(range(1 + cumsum_nseg, nseg + cumsum_nseg)) + + list(range(cumsum_nseg, nseg - 1 + cumsum_nseg)), + } + ) + .astype(int) + for nseg, cumsum_nseg in zip(self.nseg_per_branch, self.cumsum_nseg) ] ) - # `branch_list` is not needed anymore because all information it contained is - # now also in `self._comp_edges`. - del self.branch_list + self._comp_edges["type"] = 0 # Edges from branchpoints to compartments. branchpoint_to_parent_edges = pd.DataFrame().from_dict( @@ -288,6 +298,13 @@ def update_summed_coupling_conds_jaxley_spsolve( summed_conds = summed_conds.at[par_inds, -1].add(branchpoint_conds_parents) return summed_conds + def set_ncomp(self, ncomp: int, min_radius: Optional[float] = None): + """Raise an explict error if `set_ncomp` is set for an entire cell.""" + raise NotImplementedError( + "`cell.set_ncomp()` is not supported. Loop over all branches with " + "`for b in range(cell.total_nbranches): cell.branch(b).set_ncomp(n)`." + ) + class CellView(View): """CellView.""" @@ -354,27 +371,24 @@ def read_swc( ) nbranches = len(parents) - non_split = 1 / nseg - range_ = np.linspace(non_split / 2, 1 - non_split / 2, nseg) - comp = Compartment() branch = Branch([comp for _ in range(nseg)]) cell = Cell( [branch for _ in range(nbranches)], parents=parents, xyzr=coords_of_branches ) - - radiuses = np.asarray([radius_fns[b](range_) for b in range(len(parents))]) - radiuses_each = radiuses.ravel(order="C") - if min_radius is None: - assert np.all( - radiuses_each > 0.0 - ), "Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`." - else: - radiuses_each[radiuses_each < min_radius] = min_radius + # Also save the radius generating functions in case users post-hoc modify the number + # of compartments with `.set_ncomp()`. + cell._radius_generating_fns = radius_fns lengths_each = np.repeat(pathlengths, nseg) / nseg - cell.set("length", lengths_each) + + radiuses_each = build_radiuses_from_xyzr( + radius_fns, + range(len(parents)), + min_radius, + nseg, + ) cell.set("radius", radiuses_each) # Description of SWC file format: diff --git a/jaxley/modules/compartment.py b/jaxley/modules/compartment.py index 6cd24f9f..8c459681 100644 --- a/jaxley/modules/compartment.py +++ b/jaxley/modules/compartment.py @@ -37,7 +37,7 @@ def __init__(self): super().__init__() self.nseg = 1 - self.nseg_per_branch = [1] + self.nseg_per_branch = jnp.asarray([1]) self.total_nbranches = 1 self.nbranches_per_cell = [1] self.cumsum_nbranches = jnp.asarray([0, 1]) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 7e356016..ea850f23 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -21,6 +21,7 @@ convert_point_process_to_distributed, merge_cells, ) +from jaxley.utils.misc_utils import cumsum_leading_zero from jaxley.utils.solver_utils import comp_edges_to_indices, remap_index_to_masked from jaxley.utils.syn_utils import gather_synapes @@ -52,14 +53,13 @@ def __init__( [cell.nseg_per_branch for cell in self.cells] ) self.nseg = int(jnp.max(self.nseg_per_branch)) - self.cumsum_nseg = jnp.concatenate( - [jnp.asarray([0]), jnp.cumsum(self.nseg_per_branch)] - ) + self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch) + self._internal_node_inds = np.arange(self.cumsum_nseg[-1]) self._append_params_and_states(self.network_params, self.network_states) self.nbranches_per_cell = [cell.total_nbranches for cell in self.cells] self.total_nbranches = sum(self.nbranches_per_cell) - self.cumsum_nbranches = jnp.cumsum(jnp.asarray([0] + self.nbranches_per_cell)) + self.cumsum_nbranches = cumsum_leading_zero(self.nbranches_per_cell) self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True) self.nodes["comp_index"] = np.arange(self.cumsum_nseg[-1]) @@ -93,13 +93,10 @@ def __init__( self.par_inds, self.child_inds, self.child_belongs_to_branchpoint = ( compute_children_and_parents(self.branch_edges) ) - self._internal_node_inds = np.arange(self.cumsum_nseg[-1]) # `nbranchpoints` in each cell == cell.par_inds (because `par_inds` are unique). nbranchpoints = jnp.asarray([len(cell.par_inds) for cell in self.cells]) - self.cumsum_nbranchpoints_per_cell = jnp.concatenate( - [jnp.asarray([0]), jnp.cumsum(nbranchpoints)] - ) + self.cumsum_nbranchpoints_per_cell = cumsum_leading_zero(nbranchpoints) # Channels. self._gather_channels_from_constituents(cells) @@ -179,13 +176,9 @@ def _init_morph_jax_spsolve(self): `type == 3`: parent-compartment --> branchpoint `type == 4`: child-compartment --> branchpoint """ - self._cumsum_nseg_per_cell = jnp.concatenate( - [ - jnp.asarray([0]), - jnp.cumsum(jnp.asarray([cell.cumsum_nseg[-1] for cell in self.cells])), - ] + self._cumsum_nseg_per_cell = cumsum_leading_zero( + jnp.asarray([cell.cumsum_nseg[-1] for cell in self.cells]) ) - self._comp_edges = pd.DataFrame() # Add all the internal nodes. diff --git a/jaxley/utils/misc_utils.py b/jaxley/utils/misc_utils.py index 86094e03..535b11c7 100644 --- a/jaxley/utils/misc_utils.py +++ b/jaxley/utils/misc_utils.py @@ -3,6 +3,7 @@ from typing import List, Optional, Union +import jax.numpy as jnp import numpy as np import pandas as pd @@ -33,3 +34,10 @@ def childview( if child_name != "/": return module.__getattr__(child_name)(index) raise AttributeError("Compartment does not support indexing") + + +def cumsum_leading_zero(array: Union[jnp.ndarray, List]) -> jnp.ndarray: + """Return the `cumsum` of a jax array and pad with a leading zero.""" + return jnp.concatenate([jnp.asarray([0]), jnp.cumsum(jnp.asarray(array))]).astype( + int + ) diff --git a/jaxley/utils/swc.py b/jaxley/utils/swc.py index d559f764..8659d418 100644 --- a/jaxley/utils/swc.py +++ b/jaxley/utils/swc.py @@ -5,6 +5,7 @@ from typing import Callable, List, Optional, Tuple from warnings import warn +import jax.numpy as jnp import numpy as np @@ -318,3 +319,36 @@ def _compute_pathlengths( dists = np.asarray([2 * radius]) branch_pathlengths.append(dists) return branch_pathlengths + + +def build_radiuses_from_xyzr( + radius_fns: List[Callable], + branch_indices: List[int], + min_radius: Optional[float], + nseg: int, +) -> jnp.ndarray: + """Return the radiuses of branches given SWC file xyzr. + + Returns an array of shape `(num_branches, nseg)`. + + Args: + radius_fns: Functions which, given compartment locations return the radius. + branch_indices: The indices of the branches for which to return the radiuses. + min_radius: If passed, the radiuses are clipped to be at least as large. + nseg: The number of compartments that every branch is discretized into. + """ + # Compartment locations are at the center of the internal nodes. + non_split = 1 / nseg + range_ = np.linspace(non_split / 2, 1 - non_split / 2, nseg) + + # Build radiuses. + radiuses = np.asarray([radius_fns[b](range_) for b in branch_indices]) + radiuses_each = radiuses.ravel(order="C") + if min_radius is None: + assert np.all( + radiuses_each > 0.0 + ), "Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`." + else: + radiuses_each[radiuses_each < min_radius] = min_radius + + return radiuses_each diff --git a/tests/test_set_ncomp.py b/tests/test_set_ncomp.py new file mode 100644 index 00000000..27b33b63 --- /dev/null +++ b/tests/test_set_ncomp.py @@ -0,0 +1,194 @@ +# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is +# licensed under the Apache License Version 2.0, see + +import os + +import jax + +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") + +import jax.numpy as jnp +import numpy as np +import pytest + +import jaxley as jx +from jaxley.channels import HH + + +@pytest.mark.parametrize( + "property", ["radius", "capacitance", "length", "axial_resistivity"] +) +def test_raise_for_heterogenous_modules(property): + comp = jx.Compartment() + branch0 = jx.Branch(comp, nseg=4) + branch1 = jx.Branch(comp, nseg=4) + branch1.comp(1).set(property, 1.5) + cell = jx.Cell([branch0, branch1], parents=[-1, 0]) + with pytest.raises(ValueError): + cell.branch(1).set_ncomp(2) + + +def test_raise_for_heterogenous_channel_existance(): + comp = jx.Compartment() + branch0 = jx.Branch(comp, nseg=4) + branch1 = jx.Branch(comp, nseg=4) + branch1.comp(2).insert(HH()) + cell = jx.Cell([branch0, branch1], parents=[-1, 0]) + with pytest.raises(ValueError): + cell.branch(1).set_ncomp(2) + + +def test_raise_for_heterogenous_channel_properties(): + comp = jx.Compartment() + branch0 = jx.Branch(comp, nseg=4) + branch1 = jx.Branch(comp, nseg=4) + branch1.insert(HH()) + branch1.comp(3).set("HH_gNa", 0.5) + cell = jx.Cell([branch0, branch1], parents=[-1, 0]) + with pytest.raises(ValueError): + cell.branch(1).set_ncomp(2) + + +def test_raise_for_entire_cells(): + comp = jx.Compartment() + branch = jx.Branch(comp, nseg=4) + cell = jx.Cell(branch, parents=[-1, 0, 0]) + with pytest.raises(NotImplementedError): + cell.set_ncomp(2) + + +def test_raise_for_networks(): + comp = jx.Compartment() + branch = jx.Branch(comp, nseg=4) + cell1 = jx.Cell(branch, parents=[-1, 0, 0]) + cell2 = jx.Cell(branch, parents=[-1, 0, 0]) + net = jx.Network([cell1, cell2]) + with pytest.raises(NotImplementedError): + net.cell(0).branch(1).set_ncomp(2) + + +def test_raise_for_recording(): + comp = jx.Compartment() + branch = jx.Branch(comp, nseg=4) + cell = jx.Cell(branch, parents=[-1, 0]) + cell.branch(0).comp(0).record() + with pytest.raises(AssertionError): + cell.branch(1).set_ncomp(2) + + +def test_raise_for_stimulus(): + comp = jx.Compartment() + branch = jx.Branch(comp, nseg=4) + cell = jx.Cell(branch, parents=[-1, 0]) + cell.branch(0).comp(0).stimulate(0.4 * jnp.ones(100)) + with pytest.raises(AssertionError): + cell.branch(1).set_ncomp(2) + + +@pytest.mark.parametrize("new_ncomp", [1, 2, 4, 5, 8]) +def test_simulation_accuracy_api_equivalence_init_vs_setncomp_branch(new_ncomp): + """Test whether a module built from scratch matches module built with `set_ncomp()`. + + This makes one branch, whose `ncomp` is not modified, heterogenous. + """ + comp = jx.Compartment() + branch1 = jx.Branch(comp, nseg=new_ncomp) + + # The second branch is originally instantiated to have 4 ncomp, but is later + # modified to have `new_ncomp` compartments. + branch2 = jx.Branch(comp, nseg=4) + branch2.comp("all").set("length", 10.0) + total_branch_len = 4 * 10.0 + + # Make the total branch length 40 um. + branch1.comp("all").set("length", total_branch_len / new_ncomp) + + # Adapt ncomp. + branch2.set_ncomp(new_ncomp) + + for branch in [branch1, branch2]: + branch.comp(0).stimulate(0.4 * jnp.ones(100)) + branch.comp(new_ncomp - 1).record() + + v1 = jx.integrate(branch1) + v2 = jx.integrate(branch2) + max_error = np.max(np.abs(v1 - v2)) + assert max_error < 1e-8, f"Too large voltage deviation, {max_error} > 1e-8" + + +@pytest.mark.parametrize("new_ncomp", [1, 2, 4, 5, 8]) +def test_simulation_accuracy_api_equivalence_init_vs_setncomp_cell(new_ncomp): + """Test whether a module built from scratch matches module built with `set_ncomp()`.""" + comp = jx.Compartment() + branch1 = jx.Branch(comp, nseg=new_ncomp) + + # The second branch is originally instantiated to have 4 ncomp, but is later + # modified to have `new_ncomp` compartments. + branch2 = jx.Branch(comp, nseg=4) + branch2.comp("all").set("length", 10.0) + total_branch_len = 4 * 10.0 + + # Make the total branch length 20 um. + branch1.comp("all").set("length", total_branch_len / new_ncomp) + cell1 = jx.Cell(branch1, parents=[-1, 0]) + cell2 = jx.Cell(branch2, parents=[-1, 0]) + + # Adapt ncomp. + for b in range(2): + cell2.branch(b).set_ncomp(new_ncomp) + + for cell in [cell1, cell2]: + cell.branch(0).comp(0).stimulate(0.4 * jnp.ones(100)) + cell.branch(1).comp(new_ncomp - 1).record() + + v1 = jx.integrate(cell1) + v2 = jx.integrate(cell2) + max_error = np.max(np.abs(v1 - v2)) + assert max_error < 1e-8, f"Too large voltage deviation, {max_error} > 1e-8" + + +@pytest.mark.parametrize("new_ncomp", [1, 2, 4, 5, 8]) +@pytest.mark.parametrize("file", ["morph_250.swc"]) +def test_api_equivalence_swc_lengths_and_radiuses(new_ncomp, file): + """Test if the radiuses and lenghts of an SWC morph are reconstructed correctly.""" + dirname = os.path.dirname(__file__) + fname = os.path.join(dirname, "swc_files", file) + + cell1 = jx.read_swc(fname, nseg=new_ncomp, max_branch_len=2000.0) + cell2 = jx.read_swc(fname, nseg=4, max_branch_len=2000.0) + + for b in range(cell2.total_nbranches): + cell2.branch(b).set_ncomp(new_ncomp) + + for property_name in ["radius", "length"]: + cell1_vals = cell1.nodes[property_name].to_numpy() + cell2_vals = cell2.nodes[property_name].to_numpy() + assert np.allclose( + cell1_vals, cell2_vals + ), f"Too large difference in {property_name}" + + +@pytest.mark.parametrize("new_ncomp", [1, 2, 4, 5, 8]) +@pytest.mark.parametrize("file", ["morph_250.swc"]) +def test_simulation_accuracy_swc_init_vs_set_ncomp(new_ncomp, file): + """Test whether an SWC initially built with 4 ncomp works after `set_ncomp()`.""" + dirname = os.path.dirname(__file__) + fname = os.path.join(dirname, "swc_files", file) + + cell1 = jx.read_swc(fname, nseg=new_ncomp, max_branch_len=2000.0) + cell2 = jx.read_swc(fname, nseg=4, max_branch_len=2000.0) + + for b in range(cell2.total_nbranches): + cell2.branch(b).set_ncomp(new_ncomp) + + for cell in [cell1, cell2]: + cell.branch(0).comp(0).stimulate(0.4 * jnp.ones(100)) + cell.branch(0).comp(new_ncomp - 1).record() + cell.branch(3).comp(0).record() + cell.branch(5).comp(new_ncomp - 1).record() + + v1 = jx.integrate(cell1, voltage_solver="jax.sparse") + v2 = jx.integrate(cell2, voltage_solver="jax.sparse") + max_error = np.max(np.abs(v1 - v2)) + assert max_error < 1e-8, f"Too large voltage deviation, {max_error} > 1e-8"