From 387d60121d9e032d0a50b395332c0975d5e6590b Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Mon, 16 Dec 2024 18:07:50 +0100 Subject: [PATCH] fix: small fixes and comments added --- jaxley/modules/base.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index bbb30324..6d7a74f6 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1215,29 +1215,26 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: def _iter_states_params( self, params=False, states=False, currents=False - ) -> Tuple[str, jnp.ndarray]: + ) -> Tuple[str, np.ndarray]: # type: ignore # assert that either params or states is True assert params or states or currents, "Select either params / states / currents." - global_params = ["radius", "length", "axial_resistivity", "capacitance"] - global_states = ["v"] - - current_names = self.membrane_current_names + self.synapse_current_names - channel_currents = [c.current_name for c in self.channels] - all_mechs = self.channels + self.synapses - all_states = sum([list(m.states) for m in all_mechs], []) + global_states - all_params = sum([list(m.params) for m in all_mechs], []) + global_params - + if params: - for key in all_states: + global_params = ["radius", "length", "axial_resistivity", "capacitance"] + all_params = sum([list(m.params) for m in all_mechs], []) + global_params + for key in all_params: yield key, self._inds_of_state_param(key) if states: - for key in all_params: + global_states = ["v"] + all_states = sum([list(m.states) for m in all_mechs], []) + global_states + for key in all_states: yield key, self._inds_of_state_param(key) if currents: - for key in current_names + channel_currents: + current_names = self.membrane_current_names + self.synapse_current_names + for key in current_names: yield key, self._inds_of_state_param(key) def _prepare_for_jax(self): @@ -1275,9 +1272,9 @@ def inds_of_key(key: str) -> np.ndarray: for mech in self.channels + self.synapses: mech.indices = self._inds_of_state_param(mech._name) mech._jax_inds = {} - currents = {mech.current_name: None} if isinstance(mech, Channel) else {} + current = {mech.current_name: None} if isinstance(mech, Channel) else {} - for param_state in {**mech.params, **mech.states, **currents}: + for param_state in {**mech.params, **mech.states, **current}: is_global = not param_state.startswith(f"{mech._name}_") if is_global: global_inds = self._inds_of_state_param(param_state)