Skip to content

Commit

Permalink
User exposing init and step for customization. (#466)
Browse files Browse the repository at this point in the history
* Working init and step

* Refactored integrate

* Adding API equivalence test

* Fixing delta_t in APU test
  • Loading branch information
manuelgloeckler authored Oct 25, 2024
1 parent 84c87fc commit 3bf221c
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 43 deletions.
195 changes: 152 additions & 43 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from math import prod
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
import pandas as pd
Expand All @@ -12,6 +12,151 @@
from jaxley.utils.jax_utils import nested_checkpoint_scan


def build_init_and_step_fn(
module: Module,
voltage_solver: str = "jaxley.stone",
solver: str = "bwd_euler",
) -> Tuple[Callable, Callable]:
"""This function returns the `init_fn` and `step_fn` which initialize the
parameters and states of the neuron model and then step through the model
Args:
module (Module): A `Module` object that e.g. a cell.
voltage_solver (str, optional): Voltage solver used in step. Defaults to "jaxley.stone".
solver (str, optional): ODE solver. Defaults to "bwd_euler".
Returns:
init_fn, step_fn: Functions that initialize the state and parameters, and perform
a single integration step, respectively.
"""
# Initialize the external inputs and their indices.
external_inds = module.external_inds.copy()

def init_fn(
params: List[Dict[str, jnp.ndarray]],
all_states: Optional[Dict] = None,
param_state: Optional[List[Dict]] = None,
delta_t: float = 0.025,
) -> Tuple[Dict, Dict]:
"""Initializes the parameters and states of the neuron model.
Args:
params (List[Dict[str, jnp.ndarray]]): List of trainable parameters.
all_states (Optional[Dict], optional): State if alread initialized. Defaults to None.
param_state (Optional[List[Dict]], optional): Parameters returned by `data_set`.. Defaults to None.
delta_t (float, optional): Step size. Defaults to 0.025.
Returns:
Tuple[Dict, Dict]: All states and parameters.
"""
# Make the `trainable_params` of the same shape as the `param_state`, such that
# they can be processed together by `get_all_parameters`.
pstate = params_to_pstate(params, module.indices_set_by_trainables)
if param_state is not None:
pstate += param_state

all_params = module.get_all_parameters(pstate, voltage_solver=voltage_solver)
all_states = (
module.get_all_states(pstate, all_params, delta_t)
if all_states is None
else all_states
)
return all_states, all_params

def step_fn(
all_states: Dict,
all_params: Dict,
externals: Dict,
external_inds: Dict = external_inds,
delta_t: float = 0.025,
) -> Dict:
"""Performs a single integration step with step size delta_t.
Args:
all_states (Dict): Current state of the neuron model.
all_params (Dict): Current parameters of the neuron model.
externals (Dict): External inputs.
external_inds (Dict, optional): External indices. Defaults to `module.external_inds`.
delta_t (float, optional): Time step. Defaults to 0.025.
Returns:
Dict: Updated states.
"""
state = all_states
state = module.step(
state,
delta_t,
external_inds,
externals,
params=all_params,
solver=solver,
voltage_solver=voltage_solver,
)
return state

return init_fn, step_fn


def add_stimuli(
externals: Dict,
external_inds: Dict,
data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,
) -> Tuple[Dict, Dict]:
"""Extends the external inputs with the stimuli.
Args:
externals (Dict): Current external inputs.
external_inds (Dict): Current external indices.
data_stimuli (Optional[Tuple[jnp.ndarray, pd.DataFrame]], optional): Additional data stimuli. Defaults to None.
Returns:
Tuple[Dict, Dict]: Updated external inputs and indices.
"""
# If stimulus is inserted, add it to the external inputs.
if "i" in externals.keys() or data_stimuli is not None:
if "i" in externals.keys():
if data_stimuli is not None:
externals["i"] = jnp.concatenate([externals["i"], data_stimuli[1]])
external_inds["i"] = jnp.concatenate(
[external_inds["i"], data_stimuli[2].global_comp_index.to_numpy()]
)
else:
externals["i"] = data_stimuli[1]
external_inds["i"] = data_stimuli[2].global_comp_index.to_numpy()

return externals, external_inds


def add_clamps(
externals: Dict,
external_inds: Dict,
data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None,
) -> Tuple[Dict, Dict]:
"""Adds clamps to the external inputs.
Args:
externals (Dict): Current external inputs.
external_inds (Dict): Current external indices.
data_clamps (Optional[Tuple[str, jnp.ndarray, pd.DataFrame]], optional): Additional data clamps. Defaults to None.
Returns:
Tuple[Dict, Dict]: Updated external inputs and indices.
"""
# If a clamp is inserted, add it to the external inputs.
if data_clamps is not None:
state_name, clamps, inds = data_clamps
if state_name in externals.keys():
externals[state_name] = jnp.concatenate([externals[state_name], clamps])
external_inds[state_name] = jnp.concatenate(
[external_inds[state_name], inds.global_comp_index.to_numpy()]
)
else:
externals[state_name] = clamps
external_inds[state_name] = inds.global_comp_index.to_numpy()

return externals, external_inds


def integrate(
module: Module,
params: List[Dict[str, jnp.ndarray]] = [],
Expand Down Expand Up @@ -70,28 +215,10 @@ def integrate(
external_inds = module.external_inds.copy()

# If stimulus is inserted, add it to the external inputs.
if "i" in module.externals.keys() or data_stimuli is not None:
if "i" in module.externals.keys():
if data_stimuli is not None:
externals["i"] = jnp.concatenate([externals["i"], data_stimuli[1]])
external_inds["i"] = jnp.concatenate(
[external_inds["i"], data_stimuli[2].global_comp_index.to_numpy()]
)
else:
externals["i"] = data_stimuli[1]
external_inds["i"] = data_stimuli[2].global_comp_index.to_numpy()
externals, external_inds = add_stimuli(externals, external_inds, data_stimuli)

# If a clamp is inserted, add it to the external inputs.
if data_clamps is not None:
state_name, clamps, inds = data_clamps
if state_name in module.externals.keys():
externals[state_name] = jnp.concatenate([externals[state_name], clamps])
external_inds[state_name] = jnp.concatenate(
[external_inds[state_name], inds.global_comp_index.to_numpy()]
)
else:
externals[state_name] = clamps
external_inds[state_name] = inds.global_comp_index.to_numpy()
externals, external_inds = add_clamps(externals, external_inds, data_clamps)

if not externals.keys():
# No stimulus was inserted and no clamp was set.
Expand Down Expand Up @@ -126,31 +253,13 @@ def integrate(
else:
externals[key] = externals[key][:t_max_steps, :]

# Make the `trainable_params` of the same shape as the `param_state`, such that they
# can be processed together by `get_all_parameters`.
pstate = params_to_pstate(params, module.indices_set_by_trainables)

# Gather parameters from `make_trainable` and `data_set` into a single list.
if param_state is not None:
pstate += param_state

all_params = module.get_all_parameters(pstate, voltage_solver=voltage_solver)
all_states = (
module.get_all_states(pstate, all_params, delta_t)
if all_states is None
else all_states
init_fn, step_fn = build_init_and_step_fn(
module, voltage_solver=voltage_solver, solver=solver
)
all_states, all_params = init_fn(params, all_states, param_state, delta_t)

def _body_fun(state, externals):
state = module.step(
state,
delta_t,
external_inds,
externals,
params=all_params,
solver=solver,
voltage_solver=voltage_solver,
)
state = step_fn(state, all_params, externals, external_inds, delta_t)
recs = jnp.asarray(
[
state[rec_state][rec_ind]
Expand Down
42 changes: 42 additions & 0 deletions tests/test_api_equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import jaxley as jx
from jaxley.channels import HH
from jaxley.connect import connect
from jaxley.integrate import build_init_and_step_fn
from jaxley.synapses import IonotropicSynapse


Expand Down Expand Up @@ -229,3 +230,44 @@ def test_api_equivalence_network_matches_cell():

max_error = np.max(np.abs(voltages_net - voltages_cells))
assert max_error < 1e-8, f"Error is {max_error}"


def test_api_init_step_to_integrate():
comp = jx.Compartment()
branch = jx.Branch(comp, 2)
cell = jx.Cell(branch, parents=[-1, 0, 0])
cell.insert(HH())
cell[0, 1].record()

# Internal integration function API
delta_t = 0.025 # Default delta_t is 0.025
v1 = jx.integrate(cell, t_max=4.0, delta_t=delta_t)

# Flexibe init and step API
init_fn, step_fn = build_init_and_step_fn(cell)

params = cell.get_parameters()
states, params = init_fn(params)
step_fn_ = jax.jit(step_fn)
rec_inds = cell.recordings.rec_index.to_numpy()
rec_states = cell.recordings.state.to_numpy()

steps = int(4.0 / delta_t) # Steps to integrate
recordings = [
states[rec_state][rec_ind][None]
for rec_state, rec_ind in zip(rec_states, rec_inds)
]
externals = cell.externals
for _ in range(steps):
states = step_fn_(states, params, externals, delta_t=delta_t)
recs = jnp.asarray(
[
states[rec_state][rec_ind]
for rec_state, rec_ind in zip(rec_states, rec_inds)
]
)
recordings.append(recs)

rec = jnp.stack(recordings, axis=0).T

assert jnp.allclose(v1, rec)

0 comments on commit 3bf221c

Please sign in to comment.