From 409f9ec7969e186e1ad7aa84552ffa61f2909ddf Mon Sep 17 00:00:00 2001 From: Michael Deistler Date: Thu, 9 Nov 2023 09:50:16 +0100 Subject: [PATCH] move around --- neurax/modules/base.py | 34 +++++++++++++++++++++++++++++-- neurax/stimulus.py | 45 +++++++++--------------------------------- 2 files changed, 41 insertions(+), 38 deletions(-) diff --git a/neurax/modules/base.py b/neurax/modules/base.py index d0200e3f4..ddd488578 100644 --- a/neurax/modules/base.py +++ b/neurax/modules/base.py @@ -1,3 +1,4 @@ +from math import pi import inspect from abc import ABC, abstractmethod from copy import deepcopy @@ -6,10 +7,10 @@ import jax.numpy as jnp import numpy as np import pandas as pd +from jax.lax import ScatterDimensionNumbers, scatter_add from neurax.channels import Channel from neurax.solver_voltage import step_voltage_explicit, step_voltage_implicit -from neurax.stimulus import get_external_input from neurax.synapses import Synapse @@ -461,7 +462,7 @@ def step( ) # External input. - i_ext = get_external_input( + i_ext = self.get_external_input( voltages, i_inds, i_current, params["radius"], params["length"] ) @@ -560,6 +561,35 @@ def _step_synapse( voltages = u["voltages"] return [{}], jnp.zeros_like(voltages), jnp.zeros_like(voltages) + @staticmethod + def get_external_input( + voltages: jnp.ndarray, + i_inds: jnp.ndarray, + i_stim: jnp.ndarray, + radius: float, + length_single_compartment: float, + ): + """ + Return external input to each compartment in uA / cm^2. + """ + zero_vec = jnp.zeros_like(voltages) + # `radius`: um + # `length_single_compartment`: um + # `i_stim`: nA + current = ( + i_stim / 2 / pi / radius[i_inds] / length_single_compartment[i_inds] + ) # nA / um^2 + current *= 100_000 # Convert (nA / um^2) to (uA / cm^2) + + dnums = ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,), + ) + stim_at_timestep = scatter_add(zero_vec, i_inds[:, None], current, dnums) + return stim_at_timestep + + class View: """View of a `Module`.""" diff --git a/neurax/stimulus.py b/neurax/stimulus.py index dd1322c09..584cee886 100644 --- a/neurax/stimulus.py +++ b/neurax/stimulus.py @@ -1,10 +1,4 @@ -from math import pi -from typing import List, Optional - import jax.numpy as jnp -from jax.lax import ScatterDimensionNumbers, scatter_add - -from neurax.utils.cell_utils import index_of_loc def step_current( @@ -17,6 +11,10 @@ def step_current( ): """ Return step current in unit nA. + + Unlike the `datapoint_to_step()` method, this takes a single value for the amplitude + and returns a single step current. The output of this function can be passed to + `.stimulate()`, but not to `integrate(..., currents=)`. """ dt = delta_t window_start = int(i_delay / dt) @@ -26,7 +24,7 @@ def step_current( return current.at[window_start:window_end].set(i_amp) -def step_dataset( +def datapoint_to_step_currents( i_delay: float, i_dur: float, i_amp: jnp.asarray, @@ -36,40 +34,15 @@ def step_dataset( ): """ Return several step currents in unit nA. + + Unlike the `step_current()` method, this takes a vector of amplitude and returns + a step current for each value in the vector. The output of this function can be + passed to `integrate(..., currents=)`, but can not be passed to `.stimulate()`. """ dim = len(i_amp) dt = delta_t window_start = int(i_delay / dt) window_end = int((i_delay + i_dur) / dt) - time_steps = int(t_max // dt) + 2 current = jnp.zeros((time_steps, dim)) + i_offset return current.at[window_start:window_end, :].set(i_amp).T - - -def get_external_input( - voltages: jnp.ndarray, - i_inds: jnp.ndarray, - i_stim: jnp.ndarray, - radius: float, - length_single_compartment: float, -): - """ - Return external input to each compartment in uA / cm^2. - """ - zero_vec = jnp.zeros_like(voltages) - # `radius`: um - # `length_single_compartment`: um - # `i_stim`: nA - current = ( - i_stim / 2 / pi / radius[i_inds] / length_single_compartment[i_inds] - ) # nA / um^2 - current *= 100_000 # Convert (nA / um^2) to (uA / cm^2) - - dnums = ScatterDimensionNumbers( - update_window_dims=(), - inserted_window_dims=(0,), - scatter_dims_to_operand_dims=(0,), - ) - stim_at_timestep = scatter_add(zero_vec, i_inds[:, None], current, dnums) - return stim_at_timestep