Skip to content

Commit

Permalink
Allow ion diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Oct 3, 2024
1 parent 4fc75a1 commit 5ad4d69
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 39 deletions.
100 changes: 63 additions & 37 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def __init__(self):
self.channels: List[Channel] = []
self.membrane_current_names: List[str] = []

# List of all states (exluding voltage) that are being diffused.
self.diffusion_states: List[str] = []

# For trainable parameters.
self.indices_set_by_trainables: List[jnp.ndarray] = []
self.trainable_params: List[Dict[str, jnp.ndarray]] = []
Expand Down Expand Up @@ -395,6 +398,10 @@ def _data_set(
raise KeyError("Key not recognized.")
return param_state

def diffuse(self, state: str):
self.diffusion_states.append(state)
self.nodes[f"axial_resistivity_{state}"] = 1.0

def make_trainable(
self,
key: str,
Expand Down Expand Up @@ -548,6 +555,11 @@ def get_all_parameters(
for key in ["radius", "length", "axial_resistivity", "capacitance"]:
params[key] = self.jaxnodes[key]

for key in self.diffusion_states:
params[f"axial_resistivity_{key}"] = self.jaxnodes[
f"axial_resistivity_{key}"
]

for channel in self.channels:
for channel_params in channel.channel_params:
params[channel_params] = self.jaxnodes[channel_params]
Expand Down Expand Up @@ -952,25 +964,32 @@ def step(
cm = params["capacitance"] # Abbreviation.

# Arguments used by all solvers.
solver_kwargs = {
"voltages": voltages,
"voltage_terms": (v_terms + syn_v_terms) / cm,
"constant_terms": (const_terms + i_ext + syn_const_terms) / cm,
"axial_conductances": params["axial_conductances"],
"internal_node_inds": self._internal_node_inds,
state_vals = {
"voltages": jnp.stack([voltages, u["CaCon_i"]]),
"voltage_terms": jnp.stack(
[(v_terms + syn_v_terms) / cm, jnp.zeros_like(v_terms)]
),
"constant_terms": jnp.stack(
[
(const_terms + i_ext + syn_const_terms) / cm,
jnp.zeros_like(const_terms),
]
),
"axial_conductances": jnp.stack(
[params["axial_conductances"], params["axial_conductances"]]
),
}

# Add solver specific arguments.
if voltage_solver == "jax.sparse":
solver_kwargs.update(
{
"sinks": np.asarray(self._comp_edges["sink"].to_list()),
"data_inds": self._data_inds,
"indices": self._indices_jax_spsolve,
"indptr": self._indptr_jax_spsolve,
"n_nodes": self._n_nodes,
}
)
solver_kwargs = {
"data_inds": self._data_inds,
"indices": self._indices_jax_spsolve,
"indptr": self._indptr_jax_spsolve,
"sinks": np.asarray(self._comp_edges["sink"].to_list()),
"n_nodes": self._n_nodes,
"internal_node_inds": self._internal_node_inds,
}
# Only for `bwd_euler` and `cranck-nicolson`.
step_voltage_implicit = step_voltage_implicit_with_jax_spsolve
else:
Expand All @@ -980,42 +999,49 @@ def step(
# Currently, the forward Euler solver also uses this format. However,
# this is only for historical reasons and we are planning to change this in
# the future.
solver_kwargs.update(
{
"sinks": np.asarray(self._comp_edges["sink"].to_list()),
"sources": np.asarray(self._comp_edges["source"].to_list()),
"types": np.asarray(self._comp_edges["type"].to_list()),
"masked_node_inds": self._remapped_node_indices,
"nseg_per_branch": self.nseg_per_branch,
"nseg": self.nseg,
"par_inds": self.par_inds,
"child_inds": self.child_inds,
"nbranches": self.total_nbranches,
"solver": voltage_solver,
"children_in_level": self.children_in_level,
"parents_in_level": self.parents_in_level,
"root_inds": self.root_inds,
"branchpoint_group_inds": self.branchpoint_group_inds,
"debug_states": self.debug_states,
}
)
solver_kwargs = {
"internal_node_inds": self._internal_node_inds,
"sinks": np.asarray(self._comp_edges["sink"].to_list()),
"sources": np.asarray(self._comp_edges["source"].to_list()),
"types": np.asarray(self._comp_edges["type"].to_list()),
"masked_node_inds": self._remapped_node_indices,
"nseg_per_branch": self.nseg_per_branch,
"nseg": self.nseg,
"par_inds": self.par_inds,
"child_inds": self.child_inds,
"nbranches": self.total_nbranches,
"solver": voltage_solver,
"children_in_level": self.children_in_level,
"parents_in_level": self.parents_in_level,
"root_inds": self.root_inds,
"branchpoint_group_inds": self.branchpoint_group_inds,
"debug_states": self.debug_states,
}
# Only for `bwd_euler` and `cranck-nicolson`.
step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve

if solver == "bwd_euler":
u["v"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t)
nones = [None] * len(solver_kwargs)
vmapped = vmap(step_voltage_implicit, in_axes=(0, 0, 0, 0, *nones, None))
updated_states = vmapped(
*state_vals.values(), *solver_kwargs.values(), delta_t
)
u["v"] = updated_states[0]
u["CaCon_i"] = updated_states[1]
elif solver == "crank_nicolson":
# Crank-Nicolson advances by half a step of backward and half a step of
# forward Euler.
half_step_delta_t = delta_t / 2
half_step_voltages = step_voltage_implicit(
**solver_kwargs, delta_t=half_step_delta_t
**state_vals, **solver_kwargs, delta_t=half_step_delta_t
)
# The forward Euler step in Crank-Nicolson can be performed easily as
# `V_{n+1} = 2 * V_{n+1/2} - V_n`. See also NEURON book Chapter 4.
u["v"] = 2 * half_step_voltages - voltages
elif solver == "fwd_euler":
u["v"] = step_voltage_explicit(**solver_kwargs, delta_t=delta_t)
u["v"] = step_voltage_explicit(
**state_vals, **solver_kwargs, delta_t=delta_t
)
else:
raise ValueError(
f"You specified `solver={solver}`. The only allowed solvers are "
Expand Down
4 changes: 2 additions & 2 deletions jaxley/solver_voltage.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ def step_voltage_implicit_with_jaxley_spsolve(
child_inds: jnp.ndarray,
nbranches: int,
solver: str,
delta_t: float,
children_in_level: List[jnp.ndarray],
parents_in_level: List[jnp.ndarray],
root_inds: jnp.ndarray,
branchpoint_group_inds: jnp.ndarray,
debug_states,
delta_t: float,
):
"""Solve one timestep of branched nerve equations with implicit (backward) Euler."""
# Build diagonals.
Expand Down Expand Up @@ -246,9 +246,9 @@ def step_voltage_implicit_with_jax_spsolve(
indices,
indptr,
sinks,
delta_t,
n_nodes,
internal_node_inds,
delta_t,
):
axial_conductances = delta_t * axial_conductances

Expand Down

0 comments on commit 5ad4d69

Please sign in to comment.