Skip to content

Commit

Permalink
Bugfix for init_states() when channel does not exist in all comps (#…
Browse files Browse the repository at this point in the history
…421)

* Bugfix for using `init_states()` when channel does not exist in all comps

* add jaxley-mech to dev dependencies for testing

* add test for complex channel init_states

* bugfix for workflow

* adapt tutorial to #416

* formatting

* bugfix for pyproject.toml
  • Loading branch information
michaeldeistler authored Sep 16, 2024
1 parent c2c6687 commit 41ce5c3
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[dev]
pip install -e ".[dev]"
- name: Check formatting with black
run: |
Expand Down
10 changes: 10 additions & 0 deletions jaxley/channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,13 @@ def compute_current(
Current in `uA/cm2`.
"""
raise NotImplementedError

def init_state(
self,
states: Dict[str, jnp.ndarray],
v: jnp.ndarray,
params: Dict[str, jnp.ndarray],
delta_t: float,
):
"""Initialize states of channel."""
return {}
40 changes: 26 additions & 14 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
convert_point_process_to_distributed,
interpolate_xyz,
loc_of_index,
query_channel_states_and_params,
v_interp,
)
from jaxley.utils.debug_solver import compute_morphology_indices, convert_to_csc
Expand Down Expand Up @@ -608,25 +609,37 @@ def init_states(self, delta_t: float = 0.025):
channel_nodes = self.nodes
states = self.get_states_from_nodes_and_edges()

# We do not use any `pstate` for initializing. In principle, we could change
# that by allowing an input `params` and `pstate` to this function.
params = self.get_all_parameters([])

for channel in self.channels:
name = channel._name
indices = channel_nodes.loc[channel_nodes[name]]["comp_index"].to_numpy()
voltages = channel_nodes.loc[indices, "v"].to_numpy()
channel_indices = channel_nodes.loc[channel_nodes[name]][
"comp_index"
].to_numpy()
voltages = channel_nodes.loc[channel_indices, "v"].to_numpy()

channel_param_names = list(channel.channel_params.keys())
channel_params = {}
for p in channel_param_names:
channel_params[p] = channel_nodes[p][indices].to_numpy()
channel_state_names = list(channel.channel_states.keys())
channel_states = query_channel_states_and_params(
states, channel_state_names, channel_indices
)
channel_params = query_channel_states_and_params(
params, channel_param_names, channel_indices
)

init_state = channel.init_state(states, voltages, channel_params, delta_t)
init_state = channel.init_state(
channel_states, voltages, channel_params, delta_t
)

# `init_state` might not return all channel states. Only the ones that are
# returned are updated here.
for key, val in init_state.items():
# Note that we are overriding `self.nodes` here, but `self.nodes` is
# not used above to actually compute the current states (so there are
# no issues with overriding states).
self.nodes.loc[indices, key] = val
self.nodes.loc[channel_indices, key] = val

def _init_morph_for_debugging(self):
"""Instandiates row and column inds which can be used to solve the voltage eqs.
Expand Down Expand Up @@ -982,11 +995,6 @@ def _step_channels_state(
"""One integration step of the channels."""
voltages = states["v"]

query = lambda d, keys, idcs: dict(
zip(keys, (v[idcs] for v in map(d.get, keys)))
) # get dict with subset of keys and values from d
# only loops over necessary keys, as opposed to looping over d.items()

# Update states of the channels.
indices = channel_nodes["comp_index"].to_numpy()
for channel in channels:
Expand All @@ -996,8 +1004,12 @@ def _step_channels_state(
channel_state_names += self.membrane_current_names
channel_indices = indices[channel_nodes[channel._name].astype(bool)]

channel_params = query(params, channel_param_names, channel_indices)
channel_states = query(states, channel_state_names, channel_indices)
channel_params = query_channel_states_and_params(
params, channel_param_names, channel_indices
)
channel_states = query_channel_states_and_params(
states, channel_state_names, channel_indices
)

states_updated = channel.update_states(
channel_states, delta_t, voltages[channel_indices], channel_params
Expand Down
15 changes: 15 additions & 0 deletions jaxley/utils/cell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,3 +396,18 @@ def group_and_sum(
group_sums = group_sums.at[inds_to_group_by].add(values_to_sum)

return group_sums


def query_channel_states_and_params(d, keys, idcs):
"""Get dict with subset of keys and values from d.
This is used to restrict a dict where every item contains __all__ states to only
the ones that are relevant for the channel. E.g.
```states = {'eCa': Array([ 0., 0., nan]}```
will be
```states = {'eCa': Array([ 0., 0.]}```
Only loops over necessary keys, as opposed to looping over `d.items()`."""
return dict(zip(keys, (v[idcs] for v in map(d.get, keys))))
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ doc = [
dev = [
"black",
"isort",
"jaxley-mech",
"neuron",
"pytest",
"pyright",
Expand Down
90 changes: 89 additions & 1 deletion tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
from typing import Optional
from typing import Dict, Optional

import jax.numpy as jnp
import numpy as np
import pytest
from jaxley_mech.channels.l5pc import CaNernstReversal, CaPump

import jaxley as jx
from jaxley.channels import HH, CaL, CaT, Channel, K, Km, Leak, Na
from jaxley.solver_gate import save_exp, solve_inf_gate_exponential


def test_channel_set_name():
Expand Down Expand Up @@ -101,6 +103,92 @@ def test_init_states():
assert np.abs(v[0, 0] - v[0, -1]) < 0.02


class KCA11(Channel):
def __init__(self, name: Optional[str] = None):
super().__init__(name)
prefix = self._name
self.channel_params = {
f"{prefix}_q10_ch": 3,
f"{prefix}_q10_ch0": 22,
"celsius": 22,
}
self.channel_states = {f"{prefix}_m": 0.02, "CaCon_i": 1e-4}
self.current_name = f"i_K"

def update_states(
self,
states: Dict[str, jnp.ndarray],
dt,
v,
params: Dict[str, jnp.ndarray],
):
"""Update state."""
prefix = self._name
m = states[f"{prefix}_m"]
q10 = params[f"{prefix}_q10_ch"] ** (
(params["celsius"] - params[f"{prefix}_q10_ch0"]) / 10
)
cai = states["CaCon_i"]
new_m = solve_inf_gate_exponential(m, dt, *self.m_gate(v, cai, q10))
return {f"{prefix}_m": new_m}

def compute_current(
self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
"""Return current."""
prefix = self._name
m = states[f"{prefix}_m"]
g = 0.03 * m * 1000 # mS/cm^2
return g * (v + 80.0)

def init_state(self, states, v, params, dt):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
q10 = params[f"{prefix}_q10_ch"] ** (
(params["celsius"] - params[f"{prefix}_q10_ch0"]) / 10
)
cai = states["CaCon_i"]
m_inf, _ = self.m_gate(v, cai, q10)
return {f"{prefix}_m": m_inf}

@staticmethod
def m_gate(v, cai, q10):
cai = cai * 1e3
v_half = -66 + 137 * save_exp(-0.3044 * cai) + 30.24 * save_exp(-0.04141 * cai)
alpha = 25.0

beta = 0.075 / save_exp((v - v_half) / 10)
m_inf = alpha / (alpha + beta)
tau_m = 1.0 * q10
return m_inf, tau_m


def test_init_states_complex_channel():
"""Test for `init_states()` with a more complicated channel model.
The channel model used for this test uses the `states` in `init_state` and it also
uses `q10`. The model inserts the channel only is some branches. This test follows
an issue I had with Jaxley in v0.2.0 (fixed in v0.2.1).
"""
## Create cell
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=1)
cell = jx.Cell(branch, parents=[-1, 0, 0])

# CA channels.
cell.branch([0, 1]).insert(CaNernstReversal())
cell.branch([0, 1]).insert(CaPump())
cell.branch([0, 1]).insert(KCA11())

cell.init_states()

current = jx.step_current(1.0, 1.0, 0.1, 0.025, 3.0)
cell.branch(2).comp(0).stimulate(current)
cell.branch(2).comp(0).record()
voltages = jx.integrate(cell)
assert np.invert(np.any(np.isnan(voltages))), "NaN voltage found"


def test_multiple_channel_currents():
"""Test whether all channels can"""

Expand Down
Loading

0 comments on commit 41ce5c3

Please sign in to comment.