Skip to content

Commit

Permalink
Use maximal ncomp per level of the solve, not the globally maximal ncomp
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Oct 10, 2024
1 parent e78bb26 commit d169f2c
Show file tree
Hide file tree
Showing 10 changed files with 282 additions and 139 deletions.
12 changes: 3 additions & 9 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def __init__(self):
self.cumsum_nbranches: Optional[jnp.ndarray] = None

self.comb_parents: jnp.ndarray = jnp.asarray([-1])
self.comb_branches_in_each_level: List[jnp.ndarray] = [jnp.asarray([0])]

self.initialized_morph: bool = False
self.initialized_syns: bool = False
Expand Down Expand Up @@ -534,8 +533,8 @@ def _set_ncomp(
all_nodes["comp_index"] = np.arange(len(all_nodes))

# Update compartment structure arguments.
nseg_per_branch = nseg_per_branch.at[branch_indices].set(ncomp)
nseg = int(jnp.max(nseg_per_branch))
nseg_per_branch[branch_indices] = ncomp
nseg = int(np.max(nseg_per_branch))
cumsum_nseg = cumsum_leading_zero(nseg_per_branch)
internal_node_inds = np.arange(cumsum_nseg[-1])

Expand Down Expand Up @@ -1141,17 +1140,12 @@ def step(
"sinks": np.asarray(self._comp_edges["sink"].to_list()),
"sources": np.asarray(self._comp_edges["source"].to_list()),
"types": np.asarray(self._comp_edges["type"].to_list()),
"masked_node_inds": self._remapped_node_indices,
"nseg_per_branch": self.nseg_per_branch,
"nseg": self.nseg,
"par_inds": self.par_inds,
"child_inds": self.child_inds,
"nbranches": self.total_nbranches,
"solver": voltage_solver,
"children_in_level": self.children_in_level,
"parents_in_level": self.parents_in_level,
"root_inds": self.root_inds,
"branchpoint_group_inds": self.branchpoint_group_inds,
"idx": self.solve_indexer,
"debug_states": self.debug_states,
}
)
Expand Down
17 changes: 10 additions & 7 deletions jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from jaxley.modules.compartment import Compartment, CompartmentView
from jaxley.utils.cell_utils import compute_children_and_parents
from jaxley.utils.misc_utils import cumsum_leading_zero
from jaxley.utils.solver_utils import comp_edges_to_indices
from jaxley.utils.solver_utils import JaxleySolveIndexer, comp_edges_to_indices


class Branch(Module):
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(
compartment_list = compartments

self.nseg = len(compartment_list)
self.nseg_per_branch = jnp.asarray([self.nseg])
self.nseg_per_branch = np.asarray([self.nseg])
self.total_nbranches = 1
self.nbranches_per_cell = [1]
self.cumsum_nbranches = jnp.asarray([0, 1])
Expand Down Expand Up @@ -117,11 +117,14 @@ def __getattr__(self, key: str):
raise KeyError(f"Key {key} not recognized.")

def _init_morph_jaxley_spsolve(self):
self.branchpoint_group_inds = np.asarray([]).astype(int)
self.root_inds = jnp.asarray([0])
self._remapped_node_indices = self._internal_node_inds
self.children_in_level = []
self.parents_in_level = []
self.solve_indexer = JaxleySolveIndexer(
cumsum_nseg=self.cumsum_nseg,
branchpoint_group_inds=np.asarray([]).astype(int),
remapped_node_indices=self._internal_node_inds,
children_in_level=[],
parents_in_level=[],
root_inds=np.asarray([0]),
)

def _init_morph_jax_spsolve(self):
"""Initialize morphology for the jax sparse voltage solver.
Expand Down
41 changes: 31 additions & 10 deletions jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
compute_parents_in_level,
)
from jaxley.utils.misc_utils import cumsum_leading_zero
from jaxley.utils.solver_utils import comp_edges_to_indices, remap_index_to_masked
from jaxley.utils.solver_utils import (
JaxleySolveIndexer,
comp_edges_to_indices,
remap_index_to_masked,
)
from jaxley.utils.swc import build_radiuses_from_xyzr, swc_to_jaxley


Expand Down Expand Up @@ -94,8 +98,8 @@ def __init__(

# Compartment structure. These arguments have to be rebuilt when `.set_ncomp()`
# is run.
self.nseg_per_branch = jnp.asarray([branch.nseg for branch in branch_list])
self.nseg = int(jnp.max(self.nseg_per_branch))
self.nseg_per_branch = np.asarray([branch.nseg for branch in branch_list])
self.nseg = int(np.max(self.nseg_per_branch))
self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch)
self._internal_node_inds = np.arange(self.cumsum_nseg[-1])

Expand Down Expand Up @@ -168,7 +172,7 @@ def _init_morph_jaxley_spsolve(self):
self.par_inds,
self.child_inds,
)
self.branchpoint_group_inds = build_branchpoint_group_inds(
branchpoint_group_inds = build_branchpoint_group_inds(
len(self.par_inds),
self.child_belongs_to_branchpoint,
self.cumsum_nseg[-1],
Expand All @@ -178,20 +182,37 @@ def _init_morph_jaxley_spsolve(self):
parents_inds = children_and_parents["parents"]

levels = compute_levels(parents)
self.children_in_level = compute_children_in_level(levels, children_inds)
self.parents_in_level = compute_parents_in_level(
levels, self.par_inds, parents_inds
children_in_level = compute_children_in_level(levels, children_inds)
parents_in_level = compute_parents_in_level(levels, self.par_inds, parents_inds)
levels_and_nseg = pd.DataFrame().from_dict(
{
"levels": levels,
"nsegs": self.nseg_per_branch,
}
)
levels_and_nseg["max_nseg_in_level"] = levels_and_nseg.groupby("levels")[
"nsegs"
].transform("max")
padded_cumsum_nseg = cumsum_leading_zero(
levels_and_nseg["max_nseg_in_level"].to_numpy()
)
self.root_inds = jnp.asarray([0])

# Generate mapping to deal with the masking which allows using the custom
# sparse solver to deal with different nseg per branch.
self._remapped_node_indices = remap_index_to_masked(
remapped_node_indices = remap_index_to_masked(
self._internal_node_inds,
self.nodes,
self.nseg,
padded_cumsum_nseg,
self.nseg_per_branch,
)
self.solve_indexer = JaxleySolveIndexer(
cumsum_nseg=padded_cumsum_nseg,
branchpoint_group_inds=branchpoint_group_inds,
children_in_level=children_in_level,
parents_in_level=parents_in_level,
root_inds=np.asarray([0]),
remapped_node_indices=remapped_node_indices,
)

def _init_morph_jax_spsolve(self):
"""For morphology indexing with the `jax.sparse` voltage volver.
Expand Down
17 changes: 10 additions & 7 deletions jaxley/modules/compartment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
local_index_of_loc,
)
from jaxley.utils.misc_utils import cumsum_leading_zero
from jaxley.utils.solver_utils import comp_edges_to_indices
from jaxley.utils.solver_utils import JaxleySolveIndexer, comp_edges_to_indices


class Compartment(Module):
Expand All @@ -38,7 +38,7 @@ def __init__(self):
super().__init__()

self.nseg = 1
self.nseg_per_branch = jnp.asarray([1])
self.nseg_per_branch = np.asarray([1])
self.total_nbranches = 1
self.nbranches_per_cell = [1]
self.cumsum_nbranches = jnp.asarray([0, 1])
Expand Down Expand Up @@ -69,11 +69,14 @@ def __init__(self):
self.xyzr = [float("NaN") * np.zeros((2, 4))]

def _init_morph_jaxley_spsolve(self):
self.branchpoint_group_inds = np.asarray([]).astype(int)
self.root_inds = jnp.asarray([0])
self._remapped_node_indices = self._internal_node_inds
self.children_in_level = []
self.parents_in_level = []
self.solve_indexer = JaxleySolveIndexer(
cumsum_nseg=self.cumsum_nseg,
branchpoint_group_inds=np.asarray([]).astype(int),
children_in_level=[],
parents_in_level=[],
root_inds=np.asarray([0]),
remapped_node_indices=self._internal_node_inds,
)

def _init_morph_jax_spsolve(self):
"""Initialize morphology for the jax sparse voltage solver.
Expand Down
39 changes: 27 additions & 12 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from matplotlib.axes import Axes

from jaxley.modules.base import GroupView, Module, View
from jaxley.modules.branch import Branch
from jaxley.modules.cell import Cell, CellView
from jaxley.utils.cell_utils import (
build_branchpoint_group_inds,
Expand All @@ -22,7 +21,11 @@
merge_cells,
)
from jaxley.utils.misc_utils import cumsum_leading_zero
from jaxley.utils.solver_utils import comp_edges_to_indices, remap_index_to_masked
from jaxley.utils.solver_utils import (
JaxleySolveIndexer,
comp_edges_to_indices,
remap_index_to_masked,
)
from jaxley.utils.syn_utils import gather_synapes


Expand All @@ -49,10 +52,10 @@ def __init__(
self.xyzr += deepcopy(cell.xyzr)

self.cells = cells
self.nseg_per_branch = jnp.concatenate(
self.nseg_per_branch = np.concatenate(
[cell.nseg_per_branch for cell in self.cells]
)
self.nseg = int(jnp.max(self.nseg_per_branch))
self.nseg = int(np.max(self.nseg_per_branch))
self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch)
self._internal_node_inds = np.arange(self.cumsum_nseg[-1])
self._append_params_and_states(self.network_params, self.network_states)
Expand Down Expand Up @@ -129,33 +132,45 @@ def __getattr__(self, key: str):
raise KeyError(f"Key {key} not recognized.")

def _init_morph_jaxley_spsolve(self):
self.branchpoint_group_inds = build_branchpoint_group_inds(
branchpoint_group_inds = build_branchpoint_group_inds(
len(self.par_inds),
self.child_belongs_to_branchpoint,
self.cumsum_nseg[-1],
)
self.children_in_level = merge_cells(
children_in_level = merge_cells(
self.cumsum_nbranches,
self.cumsum_nbranchpoints_per_cell,
[cell.children_in_level for cell in self.cells],
[cell.solve_indexer.children_in_level for cell in self.cells],
exclude_first=False,
)
self.parents_in_level = merge_cells(
parents_in_level = merge_cells(
self.cumsum_nbranches,
self.cumsum_nbranchpoints_per_cell,
[cell.parents_in_level for cell in self.cells],
[cell.solve_indexer.parents_in_level for cell in self.cells],
exclude_first=False,
)
self.root_inds = self.cumsum_nbranches[:-1]
padded_cumsum_nseg = cumsum_leading_zero(
np.concatenate(
[np.diff(cell.solve_indexer.cumsum_nseg) for cell in self.cells]
)
)

# Generate mapping to dealing with the masking which allows using the custom
# sparse solver to deal with different nseg per branch.
self._remapped_node_indices = remap_index_to_masked(
remapped_node_indices = remap_index_to_masked(
self._internal_node_inds,
self.nodes,
self.nseg,
padded_cumsum_nseg,
self.nseg_per_branch,
)
self.solve_indexer = JaxleySolveIndexer(
cumsum_nseg=padded_cumsum_nseg,
branchpoint_group_inds=branchpoint_group_inds,
children_in_level=children_in_level,
parents_in_level=parents_in_level,
root_inds=self.cumsum_nbranches[:-1],
remapped_node_indices=remapped_node_indices,
)

def _init_morph_jax_spsolve(self):
"""Initialize the morphology for networks.
Expand Down
Loading

0 comments on commit d169f2c

Please sign in to comment.