Skip to content

Commit

Permalink
Run black and isort
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Jan 15, 2024
1 parent 4f72578 commit 0ac7d75
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
14 changes: 8 additions & 6 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from math import prod
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

import jax.numpy as jnp
import pandas as pd

from jaxley.modules import Module
from jaxley.utils.jax_utils import nested_checkpoint_scan
Expand Down Expand Up @@ -48,11 +49,12 @@ def integrate(
i_current = module.currents.T
i_inds = module.current_inds.comp_index.to_numpy()

# Append stimuli from `data_stimuli`.
i_current = jnp.concatenate(
[i_current, jnp.expand_dims(data_stimuli[0], axis=0)]
)
i_inds = np.concatenate([i_inds, data_stimuli[1].comp_index.to_numpy()])
if data_stimuli is not None:
# Append stimuli from `data_stimuli`.
i_current = jnp.concatenate(
[i_current, jnp.expand_dims(data_stimuli[0], axis=0)]
)
i_inds = np.concatenate([i_inds, data_stimuli[1].comp_index.to_numpy()])
else:
i_current = data_stimuli[0]
i_inds = data_stimuli[1].comp_index.to_numpy()
Expand Down
20 changes: 14 additions & 6 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,12 +385,15 @@ def _stimulate(self, current, view):
self.currents = jnp.expand_dims(current, axis=0)
self.current_inds = pd.concat([self.current_inds, view])


def data_stimulate(self, current, data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]]):
def data_stimulate(
self, current, data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]]
):
"""Insert a stimulus into the module within jit (or grad)."""
return self._data_stimulate(current, self.nodes)

def _data_stimulate(self, current, data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]], view):
def _data_stimulate(
self, current, data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]], view
):
assert (
len(view) == 1
), "Can only stimulate compartments, not branches, cells, or networks."
Expand All @@ -404,9 +407,7 @@ def _data_stimulate(self, current, data_stimuli: Optional[Tuple[jnp.ndarray, pd.

# Same as in `.stimulate()`.
if currents is not None:
currents = jnp.concatenate(
[currents, jnp.expand_dims(current, axis=0)]
)
currents = jnp.concatenate([currents, jnp.expand_dims(current, axis=0)])
else:
currents = jnp.expand_dims(current, axis=0)
inds = pd.concat([inds, view])
Expand Down Expand Up @@ -787,6 +788,13 @@ def stimulate(self, current: Optional[jnp.ndarray] = None):
nodes = self.set_global_index_and_index(self.view)
self.pointer._stimulate(current, nodes)

def data_stimulate(
self, current, data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]]
):
"""Insert a stimulus into the module within jit (or grad)."""
nodes = self.set_global_index_and_index(self.view)
self.pointer._data_stimulate(current, data_stimuli, nodes)

def set(self, key: str, val: float):
"""Set parameters of the pointer."""
self.pointer._set(key, val, self.view, self.pointer.nodes)
Expand Down

0 comments on commit 0ac7d75

Please sign in to comment.