fix: fix jitting issues of to_jax!
jnsbck committed Dec 13, 2024
1 parent d43ee80 commit b652e4a
Showing 1 changed file with 9 additions and 14 deletions.
jaxley/modules/base.py: 23 changes (9 additions & 14 deletions)
@@ -723,7 +723,6 @@ def _gather_channels_from_constituents(self, constituents: List):
self.base.nodes.loc[self.nodes[name].isna(), name] = False

def to_jax(self):
# TODO FROM #447: Make this work for View?
"""Move `.nodes` to `.jaxnodes`.
Before the actual simulation is run (via `jx.integrate`), all parameters of
@@ -741,12 +740,15 @@ def to_jax(self):

jaxnodes, jaxedges = {}, {}

nodes = self.nodes.to_dict(orient="list")
edges = self.edges.to_dict(orient="list")

for key, inds in self._inds_of_state_param.items():
data = self.nodes if key in self.nodes.columns else self.edges
data = nodes if key in self.nodes.columns else edges
jax_arrays = jaxnodes if key in self.nodes.columns else jaxedges

inds = self._inds_of_state_param[key]
values = data.loc[inds, key].to_numpy()
values = jnp.asarray(data[key])[inds]
jax_arrays.update({key: values})

self.jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()}
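
The gather pattern introduced in the hunk above: instead of pulling values out of the DataFrame with pandas `.loc` and an index array, the columns are dumped to plain lists once and the lookup happens on JAX arrays indexed with `jnp` index arrays. Presumably this is the jitting issue the commit title refers to: pandas `.loc` lookups do not trace under `jax.jit`, whereas indexing a `jnp` array does. A minimal, self-contained sketch of that pattern, with a made-up `radius` column and made-up indices (not taken from jaxley):

import jax
import jax.numpy as jnp
import pandas as pd

# Toy stand-ins for module.nodes and the per-mechanism indices.
nodes = pd.DataFrame({"radius": [1.0, 2.0, 3.0, 4.0]})
nodes_as_lists = nodes.to_dict(orient="list")      # {"radius": [1.0, 2.0, 3.0, 4.0]}
inds = jnp.asarray([0, 2])

# jit-friendly gather: build a JAX array once, then index it with a jnp array.
values = jnp.asarray(nodes_as_lists["radius"])[inds]

@jax.jit
def scale(values, factor):
    # downstream jitted code can consume the gathered values directly
    return values * factor

print(scale(values, 2.0))  # [2. 6.]
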
@@ -1101,7 +1103,6 @@ def make_trainable(
f"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}"
)

@only_allow_module
def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):
"""Write the trainables into `.nodes` and `.edges`.
@@ -1110,10 +1111,6 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):
Args:
trainable_params: The trainable parameters returned by `get_parameters()`.
"""
# We do not support views. Why? `jaxedges` does not have any NaN
# elements, whereas edges does. Because of this, we already need special
# treatment to make this function work, and it would be an even bigger hassle
# if we wanted to support this.
assert self.__class__.__name__ in [
"Compartment",
"Branch",
@@ -1142,7 +1139,9 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):
for parameter in pstate:
key = parameter["key"]
mech_inds = self._inds_of_state_param[key]
data = self.nodes if key in self.nodes.columns else self.edges
data = (
self.base.nodes if key in self.base.nodes.columns else self.base.edges
)
data.loc[mech_inds, key] = all_params_states[key]

def distance(self, endpoint: "View") -> float:
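
The write_trainables hunk above redirects the write target from `self.nodes`/`self.edges` to `self.base.nodes`/`self.base.edges` (the init_states hunk further down does the same). A small, hypothetical pandas illustration of why writes should go through the base tables, with made-up column names and values (not jaxley code): assigning into a filtered copy of a DataFrame leaves the original untouched, so only writes against the base DataFrame persist.

import pandas as pd

base_nodes = pd.DataFrame({"radius": [1.0, 1.0, 1.0, 1.0]})

# A stand-in for a View's table: a filtered copy of the base table.
view_nodes = base_nodes[base_nodes.index < 2].copy()
view_nodes.loc[0, "radius"] = 5.0
print(base_nodes.loc[0, "radius"])             # 1.0 -- the base table did not change

# Writing through the base table with .loc, as the new code does, persists.
base_nodes.loc[[0, 1], "radius"] = [5.0, 6.0]
print(base_nodes.loc[0, "radius"])             # 5.0
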
@@ -1259,7 +1258,6 @@ def inds_of_key(key):
param_state_inds = inds_of_key(key) if is_global else inds
self._inds_of_state_param[key] = jnp.asarray(param_state_inds)

@only_allow_module
def _get_all_states_params(
self,
pstate: List[Dict],
@@ -1305,7 +1303,6 @@ def _get_all_states_params(
)
return states_params

@only_allow_module
def get_all_parameters(
self, pstate: List[Dict], voltage_solver: str
) -> Dict[str, jnp.ndarray]:
@@ -1346,7 +1343,6 @@ def get_all_parameters(
)
return params

@only_allow_module
def get_all_states(
self, pstate: List[Dict], all_params, delta_t: float
) -> Dict[str, jnp.ndarray]:
@@ -1376,7 +1372,6 @@ def _initialize(self):
self._init_morph()
return self

@only_allow_module
def init_states(self, delta_t: float = 0.025):
# TODO FROM #447: MAKE THIS WORK FOR VIEW?
"""Initialize all mechanisms in their steady state.
@@ -1413,7 +1408,7 @@ def init_states(self, delta_t: float = 0.025):
# Note that we are overriding `self.nodes` here, but `self.nodes` is
# not used above to actually compute the current states (so there are
# no issues with overriding states).
self.nodes.loc[channel.indices, key] = val
self.base.nodes.loc[channel.indices, key] = val

def _init_morph_for_debugging(self):
"""Instandiates row and column inds which can be used to solve the voltage eqs.