diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index a41f965a..91901fc4 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -82,7 +82,6 @@ def __init__(self): self.cumsum_nbranches: Optional[jnp.ndarray] = None self.comb_parents: jnp.ndarray = jnp.asarray([-1]) - self.comb_branches_in_each_level: List[jnp.ndarray] = [jnp.asarray([0])] self.initialized_morph: bool = False self.initialized_syns: bool = False @@ -534,8 +533,8 @@ def _set_ncomp( all_nodes["comp_index"] = np.arange(len(all_nodes)) # Update compartment structure arguments. - nseg_per_branch = nseg_per_branch.at[branch_indices].set(ncomp) - nseg = int(jnp.max(nseg_per_branch)) + nseg_per_branch[branch_indices] = ncomp + nseg = int(np.max(nseg_per_branch)) cumsum_nseg = cumsum_leading_zero(nseg_per_branch) internal_node_inds = np.arange(cumsum_nseg[-1]) @@ -1141,17 +1140,12 @@ def step( "sinks": np.asarray(self._comp_edges["sink"].to_list()), "sources": np.asarray(self._comp_edges["source"].to_list()), "types": np.asarray(self._comp_edges["type"].to_list()), - "masked_node_inds": self._remapped_node_indices, "nseg_per_branch": self.nseg_per_branch, - "nseg": self.nseg, "par_inds": self.par_inds, "child_inds": self.child_inds, "nbranches": self.total_nbranches, "solver": voltage_solver, - "children_in_level": self.children_in_level, - "parents_in_level": self.parents_in_level, - "root_inds": self.root_inds, - "branchpoint_group_inds": self.branchpoint_group_inds, + "idx": self.solve_indexer, "debug_states": self.debug_states, } ) diff --git a/jaxley/modules/branch.py b/jaxley/modules/branch.py index acadcd72..7fe35e2c 100644 --- a/jaxley/modules/branch.py +++ b/jaxley/modules/branch.py @@ -13,7 +13,7 @@ from jaxley.modules.compartment import Compartment, CompartmentView from jaxley.utils.cell_utils import compute_children_and_parents from jaxley.utils.misc_utils import cumsum_leading_zero -from jaxley.utils.solver_utils import comp_edges_to_indices +from jaxley.utils.solver_utils import JaxleySolveIndexer, comp_edges_to_indices class Branch(Module): @@ -58,7 +58,7 @@ def __init__( compartment_list = compartments self.nseg = len(compartment_list) - self.nseg_per_branch = jnp.asarray([self.nseg]) + self.nseg_per_branch = np.asarray([self.nseg]) self.total_nbranches = 1 self.nbranches_per_cell = [1] self.cumsum_nbranches = jnp.asarray([0, 1]) @@ -117,11 +117,14 @@ def __getattr__(self, key: str): raise KeyError(f"Key {key} not recognized.") def _init_morph_jaxley_spsolve(self): - self.branchpoint_group_inds = np.asarray([]).astype(int) - self.root_inds = jnp.asarray([0]) - self._remapped_node_indices = self._internal_node_inds - self.children_in_level = [] - self.parents_in_level = [] + self.solve_indexer = JaxleySolveIndexer( + cumsum_nseg=self.cumsum_nseg, + branchpoint_group_inds=np.asarray([]).astype(int), + remapped_node_indices=self._internal_node_inds, + children_in_level=[], + parents_in_level=[], + root_inds=np.asarray([0]), + ) def _init_morph_jax_spsolve(self): """Initialize morphology for the jax sparse voltage solver. diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index eb21fc4b..961928ff 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -21,7 +21,11 @@ compute_parents_in_level, ) from jaxley.utils.misc_utils import cumsum_leading_zero -from jaxley.utils.solver_utils import comp_edges_to_indices, remap_index_to_masked +from jaxley.utils.solver_utils import ( + JaxleySolveIndexer, + comp_edges_to_indices, + remap_index_to_masked, +) from jaxley.utils.swc import build_radiuses_from_xyzr, swc_to_jaxley @@ -94,8 +98,8 @@ def __init__( # Compartment structure. These arguments have to be rebuilt when `.set_ncomp()` # is run. - self.nseg_per_branch = jnp.asarray([branch.nseg for branch in branch_list]) - self.nseg = int(jnp.max(self.nseg_per_branch)) + self.nseg_per_branch = np.asarray([branch.nseg for branch in branch_list]) + self.nseg = int(np.max(self.nseg_per_branch)) self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch) self._internal_node_inds = np.arange(self.cumsum_nseg[-1]) @@ -168,7 +172,7 @@ def _init_morph_jaxley_spsolve(self): self.par_inds, self.child_inds, ) - self.branchpoint_group_inds = build_branchpoint_group_inds( + branchpoint_group_inds = build_branchpoint_group_inds( len(self.par_inds), self.child_belongs_to_branchpoint, self.cumsum_nseg[-1], @@ -178,20 +182,37 @@ def _init_morph_jaxley_spsolve(self): parents_inds = children_and_parents["parents"] levels = compute_levels(parents) - self.children_in_level = compute_children_in_level(levels, children_inds) - self.parents_in_level = compute_parents_in_level( - levels, self.par_inds, parents_inds + children_in_level = compute_children_in_level(levels, children_inds) + parents_in_level = compute_parents_in_level(levels, self.par_inds, parents_inds) + levels_and_nseg = pd.DataFrame().from_dict( + { + "levels": levels, + "nsegs": self.nseg_per_branch, + } + ) + levels_and_nseg["max_nseg_in_level"] = levels_and_nseg.groupby("levels")[ + "nsegs" + ].transform("max") + padded_cumsum_nseg = cumsum_leading_zero( + levels_and_nseg["max_nseg_in_level"].to_numpy() ) - self.root_inds = jnp.asarray([0]) # Generate mapping to deal with the masking which allows using the custom # sparse solver to deal with different nseg per branch. - self._remapped_node_indices = remap_index_to_masked( + remapped_node_indices = remap_index_to_masked( self._internal_node_inds, self.nodes, - self.nseg, + padded_cumsum_nseg, self.nseg_per_branch, ) + self.solve_indexer = JaxleySolveIndexer( + cumsum_nseg=padded_cumsum_nseg, + branchpoint_group_inds=branchpoint_group_inds, + children_in_level=children_in_level, + parents_in_level=parents_in_level, + root_inds=np.asarray([0]), + remapped_node_indices=remapped_node_indices, + ) def _init_morph_jax_spsolve(self): """For morphology indexing with the `jax.sparse` voltage volver. diff --git a/jaxley/modules/compartment.py b/jaxley/modules/compartment.py index 4bfa2a44..0a00cbea 100644 --- a/jaxley/modules/compartment.py +++ b/jaxley/modules/compartment.py @@ -16,7 +16,7 @@ local_index_of_loc, ) from jaxley.utils.misc_utils import cumsum_leading_zero -from jaxley.utils.solver_utils import comp_edges_to_indices +from jaxley.utils.solver_utils import JaxleySolveIndexer, comp_edges_to_indices class Compartment(Module): @@ -38,7 +38,7 @@ def __init__(self): super().__init__() self.nseg = 1 - self.nseg_per_branch = jnp.asarray([1]) + self.nseg_per_branch = np.asarray([1]) self.total_nbranches = 1 self.nbranches_per_cell = [1] self.cumsum_nbranches = jnp.asarray([0, 1]) @@ -69,11 +69,14 @@ def __init__(self): self.xyzr = [float("NaN") * np.zeros((2, 4))] def _init_morph_jaxley_spsolve(self): - self.branchpoint_group_inds = np.asarray([]).astype(int) - self.root_inds = jnp.asarray([0]) - self._remapped_node_indices = self._internal_node_inds - self.children_in_level = [] - self.parents_in_level = [] + self.solve_indexer = JaxleySolveIndexer( + cumsum_nseg=self.cumsum_nseg, + branchpoint_group_inds=np.asarray([]).astype(int), + children_in_level=[], + parents_in_level=[], + root_inds=np.asarray([0]), + remapped_node_indices=self._internal_node_inds, + ) def _init_morph_jax_spsolve(self): """Initialize morphology for the jax sparse voltage solver. diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index ea850f23..615d08fa 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -13,7 +13,6 @@ from matplotlib.axes import Axes from jaxley.modules.base import GroupView, Module, View -from jaxley.modules.branch import Branch from jaxley.modules.cell import Cell, CellView from jaxley.utils.cell_utils import ( build_branchpoint_group_inds, @@ -22,7 +21,11 @@ merge_cells, ) from jaxley.utils.misc_utils import cumsum_leading_zero -from jaxley.utils.solver_utils import comp_edges_to_indices, remap_index_to_masked +from jaxley.utils.solver_utils import ( + JaxleySolveIndexer, + comp_edges_to_indices, + remap_index_to_masked, +) from jaxley.utils.syn_utils import gather_synapes @@ -49,10 +52,10 @@ def __init__( self.xyzr += deepcopy(cell.xyzr) self.cells = cells - self.nseg_per_branch = jnp.concatenate( + self.nseg_per_branch = np.concatenate( [cell.nseg_per_branch for cell in self.cells] ) - self.nseg = int(jnp.max(self.nseg_per_branch)) + self.nseg = int(np.max(self.nseg_per_branch)) self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch) self._internal_node_inds = np.arange(self.cumsum_nseg[-1]) self._append_params_and_states(self.network_params, self.network_states) @@ -129,33 +132,45 @@ def __getattr__(self, key: str): raise KeyError(f"Key {key} not recognized.") def _init_morph_jaxley_spsolve(self): - self.branchpoint_group_inds = build_branchpoint_group_inds( + branchpoint_group_inds = build_branchpoint_group_inds( len(self.par_inds), self.child_belongs_to_branchpoint, self.cumsum_nseg[-1], ) - self.children_in_level = merge_cells( + children_in_level = merge_cells( self.cumsum_nbranches, self.cumsum_nbranchpoints_per_cell, - [cell.children_in_level for cell in self.cells], + [cell.solve_indexer.children_in_level for cell in self.cells], exclude_first=False, ) - self.parents_in_level = merge_cells( + parents_in_level = merge_cells( self.cumsum_nbranches, self.cumsum_nbranchpoints_per_cell, - [cell.parents_in_level for cell in self.cells], + [cell.solve_indexer.parents_in_level for cell in self.cells], exclude_first=False, ) - self.root_inds = self.cumsum_nbranches[:-1] + padded_cumsum_nseg = cumsum_leading_zero( + np.concatenate( + [np.diff(cell.solve_indexer.cumsum_nseg) for cell in self.cells] + ) + ) # Generate mapping to dealing with the masking which allows using the custom # sparse solver to deal with different nseg per branch. - self._remapped_node_indices = remap_index_to_masked( + remapped_node_indices = remap_index_to_masked( self._internal_node_inds, self.nodes, - self.nseg, + padded_cumsum_nseg, self.nseg_per_branch, ) + self.solve_indexer = JaxleySolveIndexer( + cumsum_nseg=padded_cumsum_nseg, + branchpoint_group_inds=branchpoint_group_inds, + children_in_level=children_in_level, + parents_in_level=parents_in_level, + root_inds=self.cumsum_nbranches[:-1], + remapped_node_indices=remapped_node_indices, + ) def _init_morph_jax_spsolve(self): """Initialize the morphology for networks. diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py index 80c1538a..07738f6e 100644 --- a/jaxley/solver_voltage.py +++ b/jaxley/solver_voltage.py @@ -11,6 +11,7 @@ from tridiax.thomas import thomas_backsub_lower, thomas_triang_upper from jaxley.utils.cell_utils import group_and_sum +from jaxley.utils.solver_utils import JaxleySolveIndexer def step_voltage_explicit( @@ -22,18 +23,13 @@ def step_voltage_explicit( sinks: jnp.ndarray, sources: jnp.ndarray, types: jnp.ndarray, - masked_node_inds: jnp.ndarray, nseg_per_branch: jnp.ndarray, - nseg: int, par_inds: jnp.ndarray, child_inds: jnp.ndarray, nbranches: int, solver: str, delta_t: float, - children_in_level: List[jnp.ndarray], - parents_in_level: List[jnp.ndarray], - root_inds: jnp.ndarray, - branchpoint_group_inds: jnp.ndarray, + idx: JaxleySolveIndexer, debug_states, ) -> jnp.ndarray: """Solve one timestep of branched nerve equations with explicit (forward) Euler.""" @@ -54,10 +50,7 @@ def step_voltage_explicit( nbranches, solver, delta_t, - children_in_level, - parents_in_level, - root_inds, - branchpoint_group_inds, + idx, debug_states, ) new_voltates = voltages + delta_t * update @@ -73,37 +66,31 @@ def step_voltage_implicit_with_jaxley_spsolve( sinks: jnp.ndarray, sources: jnp.ndarray, types: jnp.ndarray, - masked_node_inds: jnp.ndarray, nseg_per_branch: jnp.ndarray, - nseg: int, par_inds: jnp.ndarray, child_inds: jnp.ndarray, nbranches: int, solver: str, delta_t: float, - children_in_level: List[jnp.ndarray], - parents_in_level: List[jnp.ndarray], - root_inds: jnp.ndarray, - branchpoint_group_inds: jnp.ndarray, + idx: JaxleySolveIndexer, debug_states, ): """Solve one timestep of branched nerve equations with implicit (backward) Euler.""" # Build diagonals. c2c = np.isin(types, [0, 1, 2]) - diags = jnp.ones(nbranches * nseg) + total_ncomp = idx.cumsum_nseg[-1] + diags = jnp.ones(total_ncomp) # if-case needed because `.at` does not allow empty inputs, but the input is # empty for compartments. if len(sinks[c2c]) > 0: - diags = diags.at[masked_node_inds[sinks[c2c]]].add( - delta_t * axial_conductances[c2c] - ) + diags = diags.at[idx.mask(sinks[c2c])].add(delta_t * axial_conductances[c2c]) - diags = diags.at[masked_node_inds[internal_node_inds]].add(delta_t * voltage_terms) + diags = diags.at[idx.mask(internal_node_inds)].add(delta_t * voltage_terms) # Build solves. - solves = jnp.zeros(nbranches * nseg) - solves = solves.at[masked_node_inds[internal_node_inds]].add( + solves = jnp.zeros(total_ncomp) + solves = solves.at[idx.mask(internal_node_inds)].add( voltages + delta_t * constant_terms ) @@ -111,32 +98,23 @@ def step_voltage_implicit_with_jaxley_spsolve( c2c = types == 0 # c2c = compartment-to-compartment. # Build uppers. - uppers = jnp.zeros(nbranches * nseg) + uppers = jnp.zeros(total_ncomp) upper_inds = sources[c2c] > sinks[c2c] sinks_upper = sinks[c2c][upper_inds] if len(sinks_upper) > 0: - uppers = uppers.at[masked_node_inds[sinks_upper]].add( + uppers = uppers.at[idx.mask(sinks_upper)].add( -delta_t * axial_conductances[c2c][upper_inds] ) # Build lowers. - lowers = jnp.zeros(nbranches * nseg) + lowers = jnp.zeros(total_ncomp) lower_inds = sources[c2c] < sinks[c2c] sinks_lower = sinks[c2c][lower_inds] if len(sinks_lower) > 0: - lowers = lowers.at[masked_node_inds[sinks_lower]].add( + lowers = lowers.at[idx.mask(sinks_lower)].add( -delta_t * axial_conductances[c2c][lower_inds] ) - # Reshape all diags, lowers, uppers, and solves into a "per-branch" format. - diags = jnp.reshape(diags, (nbranches, -1)) - solves = jnp.reshape(solves, (nbranches, -1)) - uppers = jnp.reshape(uppers, (nbranches, -1)) - lowers = jnp.reshape(lowers, (nbranches, -1)) - # lowers and uppers were built to have length `nseg` above for simplicity. - uppers = uppers[:, :-1] - lowers = lowers[:, 1:] - # Build branchpoint conductances. branchpoint_conds_parents = axial_conductances[types == 1] branchpoint_conds_children = axial_conductances[types == 2] @@ -148,7 +126,7 @@ def step_voltage_implicit_with_jaxley_spsolve( # Find unique group identifiers num_branchpoints = len(branchpoint_conds_parents) branchpoint_diags = -group_and_sum( - all_branchpoint_vals, branchpoint_group_inds, num_branchpoints + all_branchpoint_vals, idx.branchpoint_group_inds, num_branchpoints ) branchpoint_solves = jnp.zeros((num_branchpoints,)) @@ -201,10 +179,8 @@ def step_voltage_implicit_with_jaxley_spsolve( branchpoint_diags, branchpoint_solves, solver, - children_in_level, - parents_in_level, - root_inds, nseg_per_branch, + idx, debug_states, ) @@ -228,13 +204,11 @@ def step_voltage_implicit_with_jaxley_spsolve( branchpoint_diags, branchpoint_solves, solver, - children_in_level, - parents_in_level, - root_inds, nseg_per_branch, + idx, debug_states, ) - return solves.ravel(order="C")[masked_node_inds[internal_node_inds]] + return solves.ravel(order="C")[idx.mask(internal_node_inds)] def step_voltage_implicit_with_jax_spsolve( @@ -291,10 +265,7 @@ def _voltage_vectorfield( nbranches: int, solver: str, delta_t: float, - children_in_level: List[jnp.ndarray], - parents_in_level: List[jnp.ndarray], - root_inds: jnp.ndarray, - branchpoint_group_inds: jnp.ndarray, + idx: JaxleySolveIndexer, debug_states, ) -> jnp.ndarray: """Evaluate the vectorfield of the nerve equation.""" @@ -346,16 +317,22 @@ def _triang_branched( branchpoint_diags, branchpoint_solves, tridiag_solver, - children_in_level, - parents_in_level, - root_inds, nseg_per_branch, + idx, debug_states, ): """Triangulation.""" - for cil, pil in zip(reversed(children_in_level), reversed(parents_in_level)): + for cil, pil in zip( + reversed(idx.children_in_level), reversed(idx.parents_in_level) + ): diags, lowers, solves, uppers = _triang_level( - cil[:, 0], lowers, diags, uppers, solves, tridiag_solver + cil[:, 0], + lowers, + diags, + uppers, + solves, + tridiag_solver, + idx, ) ( branchpoint_diags, @@ -369,6 +346,7 @@ def _triang_branched( branchpoint_weights_children, branchpoint_diags, branchpoint_solves, + idx, ) diags, solves, branchpoint_conds_parents = _eliminate_parents_upper( pil, @@ -379,10 +357,11 @@ def _triang_branched( branchpoint_diags, branchpoint_solves, nseg_per_branch, + idx, ) # At last level, we do not want to eliminate anymore. diags, lowers, solves, uppers = _triang_level( - root_inds, lowers, diags, uppers, solves, tridiag_solver + idx.root_inds, lowers, diags, uppers, solves, tridiag_solver, idx ) return ( diags, @@ -408,10 +387,8 @@ def _backsub_branched( branchpoint_diags, branchpoint_solves, tridiag_solver, - children_in_level, - parents_in_level, - root_inds, nseg_per_branch, + idx, debug_states, ): """ @@ -419,10 +396,15 @@ def _backsub_branched( """ # At first level, we do not want to eliminate. solves, lowers, diags = _backsub_level( - root_inds, diags, lowers, solves, tridiag_solver + idx.root_inds, + diags, + lowers, + solves, + tridiag_solver, + idx, ) counter = 0 - for cil, pil in zip(children_in_level, parents_in_level): + for cil, pil in zip(idx.children_in_level, idx.parents_in_level): branchpoint_weights_parents, branchpoint_solves = _eliminate_parents_lower( pil, diags, @@ -430,6 +412,7 @@ def _backsub_branched( branchpoint_weights_parents, branchpoint_solves, nseg_per_branch, + idx, ) branchpoint_conds_children, solves = _eliminate_children_upper( cil, @@ -437,9 +420,10 @@ def _backsub_branched( branchpoint_conds_children, branchpoint_diags, branchpoint_solves, + idx, ) solves, lowers, diags = _backsub_level( - cil[:, 0], diags, lowers, solves, tridiag_solver + cil[:, 0], diags, lowers, solves, tridiag_solver, idx ) counter += 1 return ( @@ -452,7 +436,7 @@ def _backsub_branched( ) -def _triang_level(cil, lowers, diags, uppers, solves, tridiag_solver): +def _triang_level(cil, lowers, diags, uppers, solves, tridiag_solver, idx): if tridiag_solver == "jaxley.stone": triang_fn = stone_triang_upper elif tridiag_solver == "jaxley.thomas": @@ -460,12 +444,15 @@ def _triang_level(cil, lowers, diags, uppers, solves, tridiag_solver): else: raise NameError new_diags, new_lowers, new_solves = vmap(triang_fn, in_axes=(0, 0, 0, 0))( - lowers[cil], diags[cil], uppers[cil], solves[cil] + lowers[idx.lower(cil)], + diags[idx.branch(cil)], + uppers[idx.upper(cil)], + solves[idx.branch(cil)], ) - diags = diags.at[cil].set(new_diags) - lowers = lowers.at[cil].set(new_lowers) - solves = solves.at[cil].set(new_solves) - uppers = uppers.at[cil].set(0.0) + diags = diags.at[idx.branch(cil)].set(new_diags) + lowers = lowers.at[idx.lower(cil)].set(new_lowers) + solves = solves.at[idx.branch(cil)].set(new_solves) + uppers = uppers.at[idx.upper(cil)].set(0.0) return diags, lowers, solves, uppers @@ -476,6 +463,7 @@ def _backsub_level( lowers: jnp.ndarray, solves: jnp.ndarray, tridiag_solver: str, + idx, ) -> jnp.ndarray: bil = cil if tridiag_solver == "jaxley.stone": @@ -484,11 +472,13 @@ def _backsub_level( backsub_fn = thomas_backsub_lower else: raise NameError - solves = solves.at[bil].set( - vmap(backsub_fn, in_axes=(0, 0, 0))(solves[bil], lowers[bil], diags[bil]) + solves = solves.at[idx.branch(bil)].set( + vmap(backsub_fn, in_axes=(0, 0, 0))( + solves[idx.branch(bil)], lowers[idx.lower(bil)], diags[idx.branch(bil)] + ) ) - lowers = lowers.at[bil].set(0.0) - diags = diags.at[bil].set(1.0) + lowers = lowers.at[idx.lower(bil)].set(0.0) + diags = diags.at[idx.branch(bil)].set(1.0) return solves, lowers, diags @@ -500,12 +490,13 @@ def _eliminate_children_lower( branchpoint_weights_children, branchpoint_diags, branchpoint_solves, + idx, ): bil = cil[:, 0] bpil = cil[:, 1] new_diag, new_solve = vmap(_eliminate_single_child_lower, in_axes=(0, 0, 0, 0))( - diags[bil, 0], - solves[bil, 0], + diags[idx.first(bil)], + solves[idx.first(bil)], branchpoint_conds_children[bil], branchpoint_weights_children[bil], ) @@ -537,6 +528,7 @@ def _eliminate_parents_upper( branchpoint_diags, branchpoint_solves, nseg_per_branch: jnp.ndarray, + idx, ): bil = pil[:, 0] bpil = pil[:, 1] @@ -548,8 +540,8 @@ def _eliminate_parents_upper( ) # Update the diagonal elements and `b` in `Ax=b` (called `solves`). - diags = diags.at[bil, nseg_per_branch[bil] - 1].add(new_diag) - solves = solves.at[bil, nseg_per_branch[bil] - 1].add(new_solve) + diags = diags.at[idx.last(bil)].add(new_diag) + solves = solves.at[idx.last(bil)].add(new_solve) branchpoint_conds_parents = branchpoint_conds_parents.at[bil].set(0.0) return diags, solves, branchpoint_conds_parents @@ -575,13 +567,12 @@ def _eliminate_parents_lower( branchpoint_weights_parents, branchpoint_solves, nseg_per_branch: jnp.ndarray, + idx, ): bil = pil[:, 0] bpil = pil[:, 1] branchpoint_solves = branchpoint_solves.at[bpil].add( - -solves[bil, nseg_per_branch[bil] - 1] - * branchpoint_weights_parents[bil] - / diags[bil, nseg_per_branch[bil] - 1] + -solves[idx.last(bil)] * branchpoint_weights_parents[bil] / diags[idx.last(bil)] ) branchpoint_weights_parents = branchpoint_weights_parents.at[bil].set(0.0) return branchpoint_weights_parents, branchpoint_solves @@ -593,10 +584,11 @@ def _eliminate_children_upper( branchpoint_conds_children, branchpoint_diags, branchpoint_solves, + idx, ): bil = cil[:, 0] bpil = cil[:, 1] - solves = solves.at[bil, 0].add( + solves = solves.at[idx.first(bil)].add( -branchpoint_solves[bpil] * branchpoint_conds_children[bil] / branchpoint_diags[bpil] diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index bd772b61..99babff3 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -52,9 +52,9 @@ def compute_rad(branch_ind, loc): def merge_cells( cumsum_num_branches: List[int], cumsum_num_branchpoints: List[int], - arrs: List[List[jnp.ndarray]], + arrs: List[List[np.ndarray]], exclude_first: bool = True, -) -> jnp.ndarray: +) -> np.ndarray: """ Build full list of which branches are solved in which iteration. @@ -83,7 +83,7 @@ def merge_cells( else: p = [ p_in_level - + jnp.asarray([cumsum_num_branches[i], cumsum_num_branchpoints[i]]) + + np.asarray([cumsum_num_branches[i], cumsum_num_branchpoints[i]]) for p_in_level in p ] ps.append(p) @@ -95,7 +95,7 @@ def merge_cells( for p in ps: if len(p) > i: current_ps.append(p[i]) - combined_parents_in_level.append(jnp.concatenate(current_ps)) + combined_parents_in_level.append(np.concatenate(current_ps)) return combined_parents_in_level @@ -112,8 +112,8 @@ def compute_levels(parents): def compute_children_in_level( - levels: jnp.ndarray, children_row_and_col: jnp.ndarray -) -> List[jnp.ndarray]: + levels: np.ndarray, children_row_and_col: np.ndarray +) -> List[np.ndarray]: num_branches = len(levels) children_in_each_level = [] for l in range(1, np.max(levels) + 1): @@ -121,7 +121,7 @@ def compute_children_in_level( for b in range(num_branches): if levels[b] == l: children_in_current_level.append(children_row_and_col[b - 1]) - children_in_current_level = jnp.asarray(children_in_current_level) + children_in_current_level = np.asarray(children_in_current_level) children_in_each_level.append(children_in_current_level) return children_in_each_level @@ -130,8 +130,9 @@ def compute_parents_in_level(levels, par_inds, parents_row_and_col): level_of_parent = levels[par_inds] parents_in_each_level = [] for l in range(np.max(levels)): - parents_inds_in_current_level = jnp.where(level_of_parent == l)[0] + parents_inds_in_current_level = np.where(level_of_parent == l)[0] parents_in_current_level = parents_row_and_col[parents_inds_in_current_level] + parents_in_current_level = np.asarray(parents_in_current_level) parents_in_each_level.append(parents_in_current_level) return parents_in_each_level diff --git a/jaxley/utils/misc_utils.py b/jaxley/utils/misc_utils.py index 535b11c7..92455458 100644 --- a/jaxley/utils/misc_utils.py +++ b/jaxley/utils/misc_utils.py @@ -36,8 +36,7 @@ def childview( raise AttributeError("Compartment does not support indexing") -def cumsum_leading_zero(array: Union[jnp.ndarray, List]) -> jnp.ndarray: - """Return the `cumsum` of a jax array and pad with a leading zero.""" - return jnp.concatenate([jnp.asarray([0]), jnp.cumsum(jnp.asarray(array))]).astype( - int - ) +def cumsum_leading_zero(array: Union[np.ndarray, List]) -> np.ndarray: + """Return the `cumsum` of a numpy array and pad with a leading zero.""" + arr = np.asarray(array) + return np.concatenate([np.asarray([0]), np.cumsum(arr)]).astype(arr.dtype) diff --git a/jaxley/utils/solver_utils.py b/jaxley/utils/solver_utils.py index 50d161cb..ce0102a1 100644 --- a/jaxley/utils/solver_utils.py +++ b/jaxley/utils/solver_utils.py @@ -9,7 +9,7 @@ def remap_index_to_masked( - index, nodes: pd.DataFrame, max_nseg: int, nseg_per_branch: jnp.ndarray + index, nodes: pd.DataFrame, padded_cumsum_nseg, nseg_per_branch: jnp.ndarray ): """Convert actual index of the compartment to the index in the masked system. @@ -27,7 +27,7 @@ def remap_index_to_masked( ) branch_inds = nodes.loc[index, "branch_index"].to_numpy() remainders = index - cumsum_nseg_per_branch[branch_inds] - return branch_inds * max_nseg + remainders + return padded_cumsum_nseg[branch_inds] + remainders def convert_to_csc( @@ -97,3 +97,105 @@ def comp_edges_to_indices( col_ind=all_inds[1], ) return n_nodes, data_inds, indices, indptr + + +class JaxleySolveIndexer: + """Indexer for easy access to compartment indices given a branch index. + + Used only by the custom Jaxley solvers. This class has two purposes: + + 1) It simplifies indexing. Indexing is difficult because every branch has a + different number of compartments (in the solve, every branch within a level has + the same number of compartments, but the number can still differ between levels). + + 2) It stores several attributes such that we do not have to track all of them + separately before they are used in `step()`. + """ + + def __init__( + self, + cumsum_nseg: np.ndarray, + branchpoint_group_inds: Optional[np.ndarray] = None, + children_in_level: Optional[np.ndarray] = None, + parents_in_level: Optional[np.ndarray] = None, + root_inds: Optional[np.ndarray] = None, + remapped_node_indices: Optional[np.ndarray] = None, + ): + self.cumsum_nseg = np.asarray(cumsum_nseg) + + # Save items for easier access. + self.branchpoint_group_inds = branchpoint_group_inds + self.children_in_level = children_in_level + self.parents_in_level = parents_in_level + self.root_inds = root_inds + self.remapped_node_indices = remapped_node_indices + + def first(self, branch_inds: np.ndarray) -> np.ndarray: + """Return the indices of the first compartment of all `branch_inds`.""" + return self.cumsum_nseg[branch_inds] + + def last(self, branch_inds: np.ndarray) -> np.ndarray: + """Return the indices of the last compartment of all `branch_inds`.""" + return self.cumsum_nseg[branch_inds + 1] - 1 + + def branch(self, branch_inds: np.ndarray) -> np.ndarray: + """Return indices of all compartments in all `branch_inds`.""" + start_inds = self.first(branch_inds) + end_inds = self.last(branch_inds) + 1 + return self._consecutive_indices(start_inds, end_inds) + + def lower(self, branch_inds: np.ndarray) -> np.ndarray: + """Return indices of all lowers in all `branch_inds`. + + This is needed because the `lowers` array in the voltage solve is instantiated + to have as many elements as the `diagonal`. In this method, we get rid of + this additional element.""" + start_inds = self.first(branch_inds) + 1 + end_inds = self.last(branch_inds) + 1 + return self._consecutive_indices(start_inds, end_inds) + + def upper(self, branch_inds: np.ndarray) -> np.ndarray: + """Return indices of all uppers in all `branch_inds`. + + This is needed because the `uppers` array in the voltage solve is instantiated + to have as many elements as the `diagonal`. In this method, we get rid of + this additional element.""" + start_inds = self.first(branch_inds) + end_inds = self.last(branch_inds) + return self._consecutive_indices(start_inds, end_inds) + + def _consecutive_indices( + self, start_inds: np.ndarray, end_inds: np.ndarray + ) -> np.ndarray: + """Return array of all indices in [start, end], for every start, end. + + It also reshape the indices to `(nbranches, nseg)`. + + E.g.: + ``` + start_inds = [0, 6] + end_inds = [3, 9] + --> + [[0, 1, 2], [6, 7, 8]] + ``` + """ + n_inds = end_inds - start_inds + assert np.all(n_inds[0] == n_inds), ( + "The indexer only supports indexing into branches with the same number " + "of compartments." + ) + if n_inds[0] > 0: + repeated_starts = np.reshape(np.repeat(start_inds, n_inds), (-1, n_inds[0])) + # For single compartment neurons there are no uppers or lowers, so `n_inds` + # can be zero. + return repeated_starts + np.arange(n_inds[0]).astype(int) + else: + return np.asarray([[]] * len(start_inds)).astype(int) + + def mask(self, indices: np.ndarray) -> np.ndarray: + """Return the masked index given the global compartment index. + + The masked index is the one which occurs because all branches within a level + must have the same number of compartments for the solve. + """ + return self.remapped_node_indices[indices] diff --git a/tests/test_indexing.py b/tests/test_indexing.py index cc4bb4cf..26781450 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -15,7 +15,8 @@ import jaxley as jx from jaxley.channels import HH from jaxley.utils.cell_utils import loc_of_index, local_index_of_loc -from jaxley.utils.misc_utils import childview +from jaxley.utils.misc_utils import childview, cumsum_leading_zero +from jaxley.utils.solver_utils import JaxleySolveIndexer def test_getitem(): @@ -263,3 +264,15 @@ def test_indexing_a_compartment_of_many_branches(): net.cell("all").branch("all").loc("all") net.cell(0).branch("all").loc("all") net.cell("all").branch(0).loc("all") + + +def test_solve_indexer(): + nsegs = [4, 3, 4, 2, 2, 3, 3] + cumsum_nseg = cumsum_leading_zero(nsegs) + idx = JaxleySolveIndexer(cumsum_nseg) + branch_inds = np.asarray([0, 2]) + assert np.all(idx.first(branch_inds) == np.asarray([0, 7])) + assert np.all(idx.last(branch_inds) == np.asarray([3, 10])) + assert np.all(idx.branch(branch_inds) == np.asarray([[0, 1, 2, 3], [7, 8, 9, 10]])) + assert np.all(idx.lower(branch_inds) == np.asarray([[1, 2, 3], [8, 9, 10]])) + assert np.all(idx.upper(branch_inds) == np.asarray([[0, 1, 2], [7, 8, 9]]))