Skip to content

Commit

Permalink
move around
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 9, 2023
1 parent 8654f3e commit 6e0976b
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 39 deletions.
2 changes: 1 addition & 1 deletion neurax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from neurax.integrate import integrate
from neurax.modules import *
from neurax.optimize import ParamTransform
from neurax.stimulus import step_current, step_dataset
from neurax.stimulus import datapoint_to_step_currents, step_current
33 changes: 31 additions & 2 deletions neurax/modules/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import inspect
from abc import ABC, abstractmethod
from copy import deepcopy
from math import pi
from typing import Callable, Dict, List, Optional, Union

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax.lax import ScatterDimensionNumbers, scatter_add

from neurax.channels import Channel
from neurax.solver_voltage import step_voltage_explicit, step_voltage_implicit
from neurax.stimulus import get_external_input
from neurax.synapses import Synapse


Expand Down Expand Up @@ -461,7 +462,7 @@ def step(
)

# External input.
i_ext = get_external_input(
i_ext = self.get_external_input(
voltages, i_inds, i_current, params["radius"], params["length"]
)

Expand Down Expand Up @@ -560,6 +561,34 @@ def _step_synapse(
voltages = u["voltages"]
return [{}], jnp.zeros_like(voltages), jnp.zeros_like(voltages)

@staticmethod
def get_external_input(
voltages: jnp.ndarray,
i_inds: jnp.ndarray,
i_stim: jnp.ndarray,
radius: float,
length_single_compartment: float,
):
"""
Return external input to each compartment in uA / cm^2.
"""
zero_vec = jnp.zeros_like(voltages)
# `radius`: um
# `length_single_compartment`: um
# `i_stim`: nA
current = (
i_stim / 2 / pi / radius[i_inds] / length_single_compartment[i_inds]
) # nA / um^2
current *= 100_000 # Convert (nA / um^2) to (uA / cm^2)

dnums = ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,),
)
stim_at_timestep = scatter_add(zero_vec, i_inds[:, None], current, dnums)
return stim_at_timestep


class View:
"""View of a `Module`."""
Expand Down
45 changes: 9 additions & 36 deletions neurax/stimulus.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
from math import pi
from typing import List, Optional

import jax.numpy as jnp
from jax.lax import ScatterDimensionNumbers, scatter_add

from neurax.utils.cell_utils import index_of_loc


def step_current(
Expand All @@ -17,6 +11,10 @@ def step_current(
):
"""
Return step current in unit nA.
Unlike the `datapoint_to_step()` method, this takes a single value for the amplitude
and returns a single step current. The output of this function can be passed to
`.stimulate()`, but not to `integrate(..., currents=)`.
"""
dt = delta_t
window_start = int(i_delay / dt)
Expand All @@ -26,7 +24,7 @@ def step_current(
return current.at[window_start:window_end].set(i_amp)


def step_dataset(
def datapoint_to_step_currents(
i_delay: float,
i_dur: float,
i_amp: jnp.asarray,
Expand All @@ -36,40 +34,15 @@ def step_dataset(
):
"""
Return several step currents in unit nA.
Unlike the `step_current()` method, this takes a vector of amplitude and returns
a step current for each value in the vector. The output of this function can be
passed to `integrate(..., currents=)`, but can not be passed to `.stimulate()`.
"""
dim = len(i_amp)
dt = delta_t
window_start = int(i_delay / dt)
window_end = int((i_delay + i_dur) / dt)

time_steps = int(t_max // dt) + 2
current = jnp.zeros((time_steps, dim)) + i_offset
return current.at[window_start:window_end, :].set(i_amp).T


def get_external_input(
voltages: jnp.ndarray,
i_inds: jnp.ndarray,
i_stim: jnp.ndarray,
radius: float,
length_single_compartment: float,
):
"""
Return external input to each compartment in uA / cm^2.
"""
zero_vec = jnp.zeros_like(voltages)
# `radius`: um
# `length_single_compartment`: um
# `i_stim`: nA
current = (
i_stim / 2 / pi / radius[i_inds] / length_single_compartment[i_inds]
) # nA / um^2
current *= 100_000 # Convert (nA / um^2) to (uA / cm^2)

dnums = ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,),
)
stim_at_timestep = scatter_add(zero_vec, i_inds[:, None], current, dnums)
return stim_at_timestep

0 comments on commit 6e0976b

Please sign in to comment.