Skip to content

Commit

Permalink
fix: small fixes and comments added
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 16, 2024
1 parent 5120d7b commit 387d601
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,29 +1215,26 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:

def _iter_states_params(
self, params=False, states=False, currents=False
) -> Tuple[str, jnp.ndarray]:
) -> Tuple[str, np.ndarray]: # type: ignore
# assert that either params or states is True
assert params or states or currents, "Select either params / states / currents."
global_params = ["radius", "length", "axial_resistivity", "capacitance"]
global_states = ["v"]

current_names = self.membrane_current_names + self.synapse_current_names
channel_currents = [c.current_name for c in self.channels]

all_mechs = self.channels + self.synapses
all_states = sum([list(m.states) for m in all_mechs], []) + global_states
all_params = sum([list(m.params) for m in all_mechs], []) + global_params


if params:
for key in all_states:
global_params = ["radius", "length", "axial_resistivity", "capacitance"]
all_params = sum([list(m.params) for m in all_mechs], []) + global_params
for key in all_params:
yield key, self._inds_of_state_param(key)

if states:
for key in all_params:
global_states = ["v"]
all_states = sum([list(m.states) for m in all_mechs], []) + global_states
for key in all_states:
yield key, self._inds_of_state_param(key)

if currents:
for key in current_names + channel_currents:
current_names = self.membrane_current_names + self.synapse_current_names
for key in current_names:
yield key, self._inds_of_state_param(key)

def _prepare_for_jax(self):
Expand Down Expand Up @@ -1275,9 +1272,9 @@ def inds_of_key(key: str) -> np.ndarray:
for mech in self.channels + self.synapses:
mech.indices = self._inds_of_state_param(mech._name)
mech._jax_inds = {}
currents = {mech.current_name: None} if isinstance(mech, Channel) else {}
current = {mech.current_name: None} if isinstance(mech, Channel) else {}

for param_state in {**mech.params, **mech.states, **currents}:
for param_state in {**mech.params, **mech.states, **current}:
is_global = not param_state.startswith(f"{mech._name}_")
if is_global:
global_inds = self._inds_of_state_param(param_state)
Expand Down

0 comments on commit 387d601

Please sign in to comment.