Skip to content

Commit

Permalink
move around
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 9, 2023
1 parent 8654f3e commit dacf634
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 39 deletions.
2 changes: 1 addition & 1 deletion neurax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from neurax.integrate import integrate
from neurax.modules import *
from neurax.optimize import ParamTransform
from neurax.stimulus import step_current, step_dataset
from neurax.stimulus import step_current, datapoint_to_step_currents
34 changes: 32 additions & 2 deletions neurax/modules/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from math import pi
import inspect
from abc import ABC, abstractmethod
from copy import deepcopy
Expand All @@ -6,10 +7,10 @@
import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax.lax import ScatterDimensionNumbers, scatter_add

from neurax.channels import Channel
from neurax.solver_voltage import step_voltage_explicit, step_voltage_implicit
from neurax.stimulus import get_external_input
from neurax.synapses import Synapse


Expand Down Expand Up @@ -461,7 +462,7 @@ def step(
)

# External input.
i_ext = get_external_input(
i_ext = self.get_external_input(
voltages, i_inds, i_current, params["radius"], params["length"]
)

Expand Down Expand Up @@ -560,6 +561,35 @@ def _step_synapse(
voltages = u["voltages"]
return [{}], jnp.zeros_like(voltages), jnp.zeros_like(voltages)

@staticmethod
def get_external_input(
voltages: jnp.ndarray,
i_inds: jnp.ndarray,
i_stim: jnp.ndarray,
radius: float,
length_single_compartment: float,
):
"""
Return external input to each compartment in uA / cm^2.
"""
zero_vec = jnp.zeros_like(voltages)
# `radius`: um
# `length_single_compartment`: um
# `i_stim`: nA
current = (
i_stim / 2 / pi / radius[i_inds] / length_single_compartment[i_inds]
) # nA / um^2
current *= 100_000 # Convert (nA / um^2) to (uA / cm^2)

dnums = ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,),
)
stim_at_timestep = scatter_add(zero_vec, i_inds[:, None], current, dnums)
return stim_at_timestep



class View:
"""View of a `Module`."""
Expand Down
45 changes: 9 additions & 36 deletions neurax/stimulus.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
from math import pi
from typing import List, Optional

import jax.numpy as jnp
from jax.lax import ScatterDimensionNumbers, scatter_add

from neurax.utils.cell_utils import index_of_loc


def step_current(
Expand All @@ -17,6 +11,10 @@ def step_current(
):
"""
Return step current in unit nA.
Unlike the `datapoint_to_step()` method, this takes a single value for the amplitude
and returns a single step current. The output of this function can be passed to
`.stimulate()`, but not to `integrate(..., currents=)`.
"""
dt = delta_t
window_start = int(i_delay / dt)
Expand All @@ -26,7 +24,7 @@ def step_current(
return current.at[window_start:window_end].set(i_amp)


def step_dataset(
def datapoint_to_step_currents(
i_delay: float,
i_dur: float,
i_amp: jnp.asarray,
Expand All @@ -36,40 +34,15 @@ def step_dataset(
):
"""
Return several step currents in unit nA.
Unlike the `step_current()` method, this takes a vector of amplitude and returns
a step current for each value in the vector. The output of this function can be
passed to `integrate(..., currents=)`, but can not be passed to `.stimulate()`.
"""
dim = len(i_amp)
dt = delta_t
window_start = int(i_delay / dt)
window_end = int((i_delay + i_dur) / dt)

time_steps = int(t_max // dt) + 2
current = jnp.zeros((time_steps, dim)) + i_offset
return current.at[window_start:window_end, :].set(i_amp).T


def get_external_input(
voltages: jnp.ndarray,
i_inds: jnp.ndarray,
i_stim: jnp.ndarray,
radius: float,
length_single_compartment: float,
):
"""
Return external input to each compartment in uA / cm^2.
"""
zero_vec = jnp.zeros_like(voltages)
# `radius`: um
# `length_single_compartment`: um
# `i_stim`: nA
current = (
i_stim / 2 / pi / radius[i_inds] / length_single_compartment[i_inds]
) # nA / um^2
current *= 100_000 # Convert (nA / um^2) to (uA / cm^2)

dnums = ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,),
)
stim_at_timestep = scatter_add(zero_vec, i_inds[:, None], current, dnums)
return stim_at_timestep

0 comments on commit dacf634

Please sign in to comment.