diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 73151e50..1936bc29 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -723,7 +723,6 @@ def _gather_channels_from_constituents(self, constituents: List): self.base.nodes.loc[self.nodes[name].isna(), name] = False def to_jax(self): - # TODO FROM #447: Make this work for View? """Move `.nodes` to `.jaxnodes`. Before the actual simulation is run (via `jx.integrate`), all parameters of @@ -741,12 +740,15 @@ def to_jax(self): jaxnodes, jaxedges = {}, {} + nodes = self.nodes.to_dict(orient="list") + edges = self.edges.to_dict(orient="list") + for key, inds in self._inds_of_state_param.items(): - data = self.nodes if key in self.nodes.columns else self.edges + data = nodes if key in self.nodes.columns else edges jax_arrays = jaxnodes if key in self.nodes.columns else jaxedges inds = self._inds_of_state_param[key] - values = data.loc[inds, key].to_numpy() + values = jnp.asarray(data[key])[inds] jax_arrays.update({key: values}) self.jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} @@ -1101,7 +1103,6 @@ def make_trainable( f"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}" ) - @only_allow_module def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): """Write the trainables into `.nodes` and `.edges`. @@ -1110,10 +1111,6 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): Args: trainable_params: The trainable parameters returned by `get_parameters()`. """ - # We do not support views. Why? `jaxedges` does not have any NaN - # elements, whereas edges does. Because of this, we already need special - # treatment to make this function work, and it would be an even bigger hassle - # if we wanted to support this. assert self.__class__.__name__ in [ "Compartment", "Branch", @@ -1142,7 +1139,9 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): for parameter in pstate: key = parameter["key"] mech_inds = self._inds_of_state_param[key] - data = self.nodes if key in self.nodes.columns else self.edges + data = ( + self.base.nodes if key in self.base.nodes.columns else self.base.edges + ) data.loc[mech_inds, key] = all_params_states[key] def distance(self, endpoint: "View") -> float: @@ -1259,7 +1258,6 @@ def inds_of_key(key): param_state_inds = inds_of_key(key) if is_global else inds self._inds_of_state_param[key] = jnp.asarray(param_state_inds) - @only_allow_module def _get_all_states_params( self, pstate: List[Dict], @@ -1305,7 +1303,6 @@ def _get_all_states_params( ) return states_params - @only_allow_module def get_all_parameters( self, pstate: List[Dict], voltage_solver: str ) -> Dict[str, jnp.ndarray]: @@ -1346,7 +1343,6 @@ def get_all_parameters( ) return params - @only_allow_module def get_all_states( self, pstate: List[Dict], all_params, delta_t: float ) -> Dict[str, jnp.ndarray]: @@ -1376,7 +1372,6 @@ def _initialize(self): self._init_morph() return self - @only_allow_module def init_states(self, delta_t: float = 0.025): # TODO FROM #447: MAKE THIS WORK FOR VIEW? """Initialize all mechanisms in their steady state. @@ -1413,7 +1408,7 @@ def init_states(self, delta_t: float = 0.025): # Note that we are overriding `self.nodes` here, but `self.nodes` is # not used above to actually compute the current states (so there are # no issues with overriding states). - self.nodes.loc[channel.indices, key] = val + self.base.nodes.loc[channel.indices, key] = val def _init_morph_for_debugging(self): """Instandiates row and column inds which can be used to solve the voltage eqs.