Skip to content

Commit

Permalink
Make several methods private; delete unused methods (#459)
Browse files Browse the repository at this point in the history
* Make several methods private; delete unused methods

* remove init_syns()
  • Loading branch information
michaeldeistler authored Oct 23, 2024
1 parent 26acd32 commit 61c5c27
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 61 deletions.
2 changes: 1 addition & 1 deletion jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def integrate(
the `Module` can be set with `set_states`.
"""

assert module.initialized, "Module is not initialized, run `.initialize()`."
assert module.initialized, "Module is not initialized, run `._initialize()`."
module.to_jax() # Creates `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.

# Initialize the external inputs and their indices.
Expand Down
19 changes: 8 additions & 11 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def show(

return nodes[cols]

def init_morph(self):
def _init_morph(self):
"""Initialize the morphology such that it can be processed by the solvers."""
self._init_morph_jaxley_spsolve()
self._init_morph_jax_spsolve()
Expand Down Expand Up @@ -963,7 +963,7 @@ def set_ncomp(
self.base._internal_node_inds = internal_node_inds

# Update the morphology indexing (e.g., `.comp_edges`).
self.base.initialize()
self.base._initialize()
self.base._init_view()
self.base._update_local_indices()

Expand Down Expand Up @@ -1173,7 +1173,7 @@ def get_all_parameters(
return params

# TODO: MAKE THIS WORK FOR VIEW?
def get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]:
def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]:
"""Return states as they are set in the `.nodes` and `.edges` tables."""
self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.
states = {"v": self.base.jaxnodes["v"]}
Expand All @@ -1199,7 +1199,7 @@ def get_all_states(
Returns:
A dictionary of all states of the module.
"""
states = self.base.get_states_from_nodes_and_edges()
states = self.base._get_states_from_nodes_and_edges()

# Override with the initial states set by `.make_trainable()`.
for parameter in pstate:
Expand Down Expand Up @@ -1227,11 +1227,11 @@ def get_all_states(
@property
def initialized(self) -> bool:
"""Whether the `Module` is ready to be solved or not."""
return self.initialized_morph and self.initialized_syns
return self.initialized_morph

def initialize(self):
def _initialize(self):
"""Initialize the module."""
self.init_morph()
self._init_morph()
return self

# TODO: MAKE THIS WORK FOR VIEW?
Expand All @@ -1245,7 +1245,7 @@ def init_states(self, delta_t: float = 0.025):
"""
# Update states of the channels.
channel_nodes = self.base.nodes
states = self.base.get_states_from_nodes_and_edges()
states = self.base._get_states_from_nodes_and_edges()

# We do not use any `pstate` for initializing. In principle, we could change
# that by allowing an input `params` and `pstate` to this function.
Expand Down Expand Up @@ -1557,9 +1557,6 @@ def insert(self, channel: Channel):
for key in channel.channel_states:
self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key]

def init_syns(self):
self.initialized_syns = True

def step(
self,
u: Dict[str, jnp.ndarray],
Expand Down
3 changes: 1 addition & 2 deletions jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def __init__(
)
self._internal_node_inds = jnp.arange(self.nseg)

self.initialize()
self.init_syns()
self._initialize()

# Coordinates.
self.xyzr = [float("NaN") * np.zeros((2, 4))]
Expand Down
27 changes: 1 addition & 26 deletions jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ def __init__(
compute_children_and_parents(self.branch_edges)
)

self.initialize()
self.init_syns()
self._initialize()

def _init_morph_jaxley_spsolve(self):
"""Initialize morphology for the custom sparse solver.
Expand Down Expand Up @@ -271,30 +270,6 @@ def _init_morph_jax_spsolve(self):
self._indices_jax_spsolve = indices
self._indptr_jax_spsolve = indptr

@staticmethod
def update_summed_coupling_conds_jaxley_spsolve(
summed_conds,
child_inds,
par_inds,
branchpoint_conds_children,
branchpoint_conds_parents,
):
"""Perform updates on the diagonal based on conductances of the branchpoints.
Args:
summed_conds: shape [num_branches, nseg]
child_inds: shape [num_branches - 1]
conds_fwd: shape [num_branches - 1]
conds_bwd: shape [num_branches - 1]
parents: shape [num_branches]
Returns:
Updated `summed_coupling_conds`.
"""
summed_conds = summed_conds.at[child_inds, 0].add(branchpoint_conds_children)
summed_conds = summed_conds.at[par_inds, -1].add(branchpoint_conds_parents)
return summed_conds


def read_swc(
fname: str,
Expand Down
9 changes: 1 addition & 8 deletions jaxley/modules/compartment.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def __init__(self):
self._internal_node_inds = jnp.asarray([0])

# Initialize the module.
self.initialize()
self.init_syns()
self._initialize()

# Coordinates.
self.xyzr = [float("NaN") * np.zeros((2, 4))]
Expand Down Expand Up @@ -93,9 +92,3 @@ def _init_morph_jax_spsolve(self):
self._data_inds = data_inds
self._indices_jax_spsolve = indices
self._indptr_jax_spsolve = indptr

def init_conds(self, params: Dict[str, jnp.ndarray]):
"""Override `Base.init_axial_conds()`.
This is because compartments do not have any axial conductances."""
return {"axial_conductances": jnp.asarray([])}
14 changes: 1 addition & 13 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def __init__(
# Channels.
self._gather_channels_from_constituents(cells)

self.initialize()
self.init_syns()
self._initialize()

def __repr__(self):
return f"{type(self).__name__} with {len(self.channels)} different channels and {len(self.synapses)} synapses. Use `.nodes` or `.edges` for details."
Expand Down Expand Up @@ -239,17 +238,6 @@ def _init_morph_jax_spsolve(self):
self._indices_jax_spsolve = indices
self._indptr_jax_spsolve = indptr

def init_syns(self):
"""Initialize synapses."""
self.synapses = []

# TODO(@michaeldeistler): should we also track this for channels?
self.synapse_names = []
self.synapse_param_names = []
self.synapse_state_names = []

self.initialized_syns = True

def _step_synapse(
self,
states: Dict,
Expand Down

0 comments on commit 61c5c27

Please sign in to comment.