wip: more tests passing some fixes
jnsbck committed Dec 5, 2024
1 parent aa4ae5f commit 2ad0b92
Showing 7 changed files with 137 additions and 195 deletions.
201 changes: 110 additions & 91 deletions jaxley/modules/base.py
@@ -746,7 +746,7 @@ def to_jax(self):
inds = (
all_inds[data["type"] == mech._name]
if "type" in data.columns
else all_inds[self.nodes[mech._name]]
else all_inds[data[mech._name]]
)
states_params = list(mech.params) + list(mech.states)
params = data[states_params].loc[inds]
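
As a toy illustration of the membership lookup in the changed line above (column and mechanism names are made up): each mechanism gets a bool column in the nodes table, and indexing `all_inds` with that column yields the compartment indices that carry the mechanism.

import numpy as np
import pandas as pd

# Toy nodes table: the bool "HH" column marks which compartments carry the channel.
data = pd.DataFrame({"HH": [True, False, True, False]})
all_inds = data.index.to_numpy()

hh_inds = all_inds[data["HH"]]   # -> array([0, 2])
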
@@ -1107,6 +1107,7 @@ def make_trainable(
f"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}"
)

@only_allow_module
def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):
"""Write the trainables into `.nodes` and `.edges`.
@@ -1131,22 +1132,26 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):
# any kind of issues with indexing or parameter sharing (as this is fully
# taken care of by `get_all_parameters()`).
self.base.to_jax()
pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables)
all_params = self.base.get_all_parameters(pstate, voltage_solver="jaxley.stone")

# The value for `delta_t` does not matter here because it is only used to
# compute the initial current. However, the initial current cannot be made
# trainable and so its value never gets used below.
all_states = self.base.get_all_states(pstate, all_params, delta_t=0.025)
pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables)
all_params_states = self.base._get_all_states_params(
pstate,
delta_t=0.025,
voltage_solver="jaxley.stone",
params=True,
states=True,
)

# Loop only over the keys in `pstate` to avoid unnecessary computation.
for parameter in pstate:
key = parameter["key"]
vals_to_set = all_params if key in all_params.keys() else all_states
if key in self.base.nodes.columns:
self.base.nodes[key] = vals_to_set[key]
if key in self.base.edges.columns:
self.base.edges[key] = vals_to_set[key]
mech, mech_inds = self.base._get_mech_inds_of_param_state(key)
data = (
self.base.nodes if key in self.base.nodes.columns else self.base.edges
)
data.loc[mech_inds, key] = all_params_states[key]
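
For context, a minimal sketch of the round trip this method supports, assuming the standard `HH` channel (so the trainable key is `HH_gNa`) and a toy single-branch cell; the construction is illustrative only.

import jaxley as jx
from jaxley.channels import HH

comp = jx.Compartment()
branch = jx.Branch(comp, ncomp=4)
cell = jx.Cell(branch, parents=[-1])
cell.insert(HH())

cell.make_trainable("HH_gNa")
params = cell.get_parameters()                   # List[Dict[str, jnp.ndarray]]
params[0]["HH_gNa"] = params[0]["HH_gNa"] * 1.1  # e.g. one optimizer step
cell.write_trainables(params)                    # updated values land back in cell.nodes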

def distance(self, endpoint: "View") -> float:
"""Return the direct distance between two compartments.
@@ -1219,26 +1224,87 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:
"""
return self.trainable_params

@only_allow_module
def _iter_states_or_params(self, type="states") -> Dict[str, jnp.ndarray]:
def _iter_states_params(
self, params=False, states=False
) -> Tuple[str, jnp.ndarray, jnp.ndarray]:
# TODO FROM #447: MAKE THIS WORK FOR VIEW?
"""Return states as they are set in the `.nodes` and `.edges` tables."""

# assert that either params or states is True
assert params or states, "Either params or states must be True."
morph_params = ["radius", "length", "axial_resistivity", "capacitance"]
global_states = ["v"]
global_states_or_params = morph_params if type == "params" else global_states
for key in global_states_or_params:
yield key, self.base.jaxnodes["index"], self.base.jaxnodes[key]
global_states_params = morph_params if params else []
global_states_params += ["v"] if states else []
for key in global_states_params:
yield key, self.jaxnodes["index"], self.jaxnodes[key]

# Join node and edge states into a single state dictionary.
for jax_arrays, mechs in zip(
[self.base.jaxnodes, self.base.jaxedges],
[self.base.channels, self.base.synapses],
[self.jaxnodes, self.jaxedges],
[self.channels, self.synapses],
):
for mech in mechs:
mech_inds = jax_arrays[mech._name]
for key in mech.__dict__[type]:
# Copy so that `.update()` below does not mutate the mechanism's own params dict.
mech_params_states = dict(mech.__dict__["params"]) if params else {}
mech_params_states.update(mech.__dict__["states"] if states else {})
for key in mech_params_states:
yield key, mech_inds, jax_arrays[key]
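
A sketch of how this generator is consumed, mirroring the loop in `_get_all_states_params` below (the cell construction is again illustrative, and `.to_jax()` must have been called so that `.jaxnodes`/`.jaxedges` exist).

import jaxley as jx
from jaxley.channels import HH

comp = jx.Compartment()
branch = jx.Branch(comp, ncomp=4)
cell = jx.Cell(branch, parents=[-1])
cell.insert(HH())
cell.to_jax()  # populate .jaxnodes / .jaxedges

params = {}
for key, mech_inds, jax_array in cell._iter_states_params(params=True):
    params[key] = jax_array  # keys like "radius", "capacitance", "HH_gNa", ...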

def _get_mech_inds_of_param_state(self, key: str) -> Tuple[str, jnp.ndarray]:
"""Return the mechanism that `key` belongs to (None for global keys) and its indices."""
jax_array = self.jaxnodes if key in self.nodes.columns else self.jaxedges

if "_" in key and key not in ["axial_resistivity", "axial_conductances"]:
mech = key.split("_")[0]
return mech, jax_array[mech]

return None, jax_array["index"]
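
A standalone sketch of the key convention this helper relies on (channel and synapse names are illustrative): mechanism-scoped keys are prefixed with the mechanism name, so splitting at the first underscore recovers it, while global keys and the explicitly excluded `axial_*` ones map to the plain index array.

def mech_of(key: str):
    # Mirrors the heuristic above, outside the Module class.
    if "_" in key and key not in ["axial_resistivity", "axial_conductances"]:
        return key.split("_")[0]
    return None

assert mech_of("HH_gNa") == "HH"                              # channel parameter
assert mech_of("IonotropicSynapse_s") == "IonotropicSynapse"  # synapse state
assert mech_of("radius") is None                              # global morphology parameter
assert mech_of("axial_resistivity") is None                   # explicitly excluded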

@only_allow_module
def _get_all_states_params(
self,
pstate: List[Dict],
voltage_solver=None,
delta_t=None,
all_params=None,
params=False,
states=False,
) -> Dict[str, jnp.ndarray]:
states_params = {}
for key, _, jax_array in self.base._iter_states_params(params, states):
states_params[key] = jax_array

# Override with those parameters set by `.make_trainable()`.
for parameter in pstate:
key = parameter["key"]
inds = parameter["indices"]
set_param = parameter["val"]

if key in states_params:
mech, mech_inds = self.base._get_mech_inds_of_param_state(key)
# `inds` is of shape `(num_params, num_comps_per_param)`.
# `set_param` is of shape `(num_params,)`
# We need to unsqueeze `set_param` to make it `(num_params, 1)`
# for the `.set()` to work. This is done with `[:, None]`.
inds = np.searchsorted(mech_inds, inds)
states_params[key] = states_params[key].at[inds].set(set_param[:, None])

if params:
# Compute conductance params and add them to the params dictionary.
states_params["axial_conductances"] = self.base._compute_axial_conductances(
params=states_params
)
if states:
all_params = states_params if all_params is None and params else all_params
# Add to the states the initial current through every channel.
states_params, _ = self.base._channel_currents(
states_params, delta_t, self.base.channels, self.base.nodes, all_params
)

# Add to the states the initial current through every synapse.
states_params, _ = self.base._synapse_currents(
states_params, self.base.synapses, all_params, delta_t, self.base.edges
)
return states_params
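
The `np.searchsorted` step above translates global compartment indices (as stored in `pstate`) into positions within a mechanism's own parameter array before the scatter `.at[...].set(...)`. A toy example with made-up indices:

import numpy as np
import jax.numpy as jnp

mech_inds = np.array([2, 3, 6, 7])         # compartments that carry the mechanism
param = jnp.asarray([0.1, 0.1, 0.1, 0.1])  # one value per such compartment

inds = np.array([[3], [7]])                # global indices from make_trainable
set_param = jnp.asarray([0.5, 0.9])        # one trainable value per row of `inds`

local = np.searchsorted(mech_inds, inds)   # -> [[1], [3]]
param = param.at[local].set(set_param[:, None])
print(param)                               # [0.1 0.5 0.1 0.9]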

@only_allow_module
def get_all_parameters(
self, pstate: List[Dict], voltage_solver: str
@@ -1275,29 +1341,8 @@ def get_all_parameters(
Returns:
A dictionary of all module parameters.
"""
pstate_inds = {d["key"]: i for i, d in enumerate(pstate)}

params = {}
for key, mech_inds, jax_array in self._iter_states_or_params("params"):
params[key] = jax_array

# Override with those parameters set by `.make_trainable()`.
if key in pstate_inds:
idx = pstate_inds[key]
key = pstate[idx]["key"]
inds = pstate[idx]["indices"]
set_param = pstate[idx]["val"]

# `inds` is of shape `(num_params, num_comps_per_param)`.
# `set_param` is of shape `(num_params,)`
# We need to unsqueeze `set_param` to make it `(num_params, 1)`
# for the `.set()` to work. This is done with `[:, None]`.
inds = np.searchsorted(mech_inds, inds)
params[key] = params[key].at[inds].set(set_param[:, None])

# Compute conductance params and add them to the params dictionary.
params["axial_conductances"] = self.base._compute_axial_conductances(
params=params
params = self._get_all_states_params(
pstate, params=True, voltage_solver=voltage_solver
)
return params

@@ -1316,33 +1361,8 @@ def get_all_states(
Returns:
A dictionary of all states of the module.
"""
pstate_inds = {d["key"]: i for i, d in enumerate(pstate)}
states = {}
for key, mech_inds, jax_array in self._iter_states_or_params("states"):
states[key] = jax_array

# Override with those parameters set by `.make_trainable()`.
if key in pstate_inds:
idx = pstate_inds[key]
key = pstate[idx]["key"]
inds = pstate[idx]["indices"]
set_param = pstate[idx]["val"]

# `inds` is of shape `(num_states, num_comps_per_param)`.
# `set_param` is of shape `(num_states,)`
# We need to unsqueeze `set_param` to make it `(num_states, 1)`
# for the `.set()` to work. This is done with `[:, None]`.
inds = np.searchsorted(mech_inds, inds)
states[key] = states[key].at[inds].set(set_param[:, None])

# Add to the states the initial current through every channel.
states, _ = self.base._channel_currents(
states, delta_t, self.channels, self.nodes, all_params
)

# Add to the states the initial current through every synapse.
states, _ = self.base._synapse_currents(
states, self.synapses, all_params, delta_t, self.edges
states = self._get_all_states_params(
pstate, states=True, all_params=all_params, delta_t=delta_t
)
return states

@@ -1370,7 +1390,7 @@ def init_states(self, delta_t: float = 0.025):
self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.
channel_nodes = self.base.nodes
states = {}
for key, _, jax_array in self._iter_states_or_params("states"):
for key, _, jax_array in self.base._iter_states_params(states=True):
states[key] = jax_array

# We do not use any `pstate` for initializing. In principle, we could change
@@ -2503,26 +2523,25 @@ def _set_inds_in_view(
def _jax_arrays_in_view(self, pointer: Union[Module, View]):
"""Update jaxnodes/jaxedges to show only those currently in view."""
a_intersects_b_at = lambda a, b: jnp.intersect1d(a, b, return_indices=True)[1]
jaxnodes = {} if pointer.jaxnodes is not None else None
if self.jaxnodes is not None:
comp_inds = pointer.jaxnodes["global_comp_index"]
common_inds = a_intersects_b_at(comp_inds, self._nodes_in_view)
jaxnodes = {
k: v[common_inds]
for k, v in pointer.jaxnodes.items()
if len(common_inds) > 0
}

jaxedges = {} if pointer.jaxedges is not None else None
if pointer.jaxedges is not None:
for key, values in self.base.jaxedges.items():
if (syn_name := key.split("_")[0]) in self.synapse_names:
syn_edges = self.base.edges[self.base.edges["type"] == syn_name]
inds = np.intersect1d(
self._edges_in_view, syn_edges.index, return_indices=True
)[2]
if len(inds) > 0:
jaxedges[key] = values[inds]

jaxnodes = {} if self.base.jaxnodes is not None else None
jaxedges = {} if self.base.jaxedges is not None else None

mechs = [m._name for m in self.channels + self.synapses if m is not None]
for jax_array, base_jax_array, viewed_inds in zip(
[jaxnodes, jaxedges],
[self.base.jaxnodes, self.base.jaxedges],
[self._nodes_in_view, self._edges_in_view],
):
if base_jax_array is not None and len(viewed_inds) > 0:
for key, values in base_jax_array.items():
mech, mech_inds = self.base._get_mech_inds_of_param_state(key)
if mech is None or mech in mechs:
jax_array[key] = values[
a_intersects_b_at(mech_inds, viewed_inds)
]
jax_array["index"] = np.asarray(viewed_inds)

return jaxnodes, jaxedges
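
A toy illustration of the `a_intersects_b_at` helper used above: it returns the positions, within the first array, of the entries that also occur in the second, which is how the view slices the full `jaxnodes`/`jaxedges` arrays down to the rows currently in view.

import jax.numpy as jnp

a_intersects_b_at = lambda a, b: jnp.intersect1d(a, b, return_indices=True)[1]

mech_inds = jnp.asarray([2, 3, 6, 7])  # rows of one mechanism's arrays
in_view = jnp.asarray([3, 4, 5, 6])    # node indices currently in view

positions = a_intersects_b_at(mech_inds, in_view)  # -> [1 2]
values = jnp.asarray([10.0, 11.0, 12.0, 13.0])
print(values[positions])                           # [11. 12.]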

def _set_externals_in_view(self):
4 changes: 2 additions & 2 deletions jaxley/modules/branch.py
@@ -10,7 +10,7 @@

from jaxley.modules.base import Module
from jaxley.modules.compartment import Compartment
from jaxley.utils.cell_utils import compute_children_and_parents
from jaxley.utils.cell_utils import compute_children_and_parents, dtype_aware_concat
from jaxley.utils.misc_utils import cumsum_leading_zero, deprecated_kwargs
from jaxley.utils.solver_utils import JaxleySolveIndexer, comp_edges_to_indices

@@ -73,7 +73,7 @@ def __init__(
self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)

# Indexing.
self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)
self.nodes = dtype_aware_concat([c.nodes for c in compartment_list])
self._append_params_and_states(self.branch_params, self.branch_states)
self.nodes["global_comp_index"] = np.arange(self.ncomp).tolist()
self.nodes["global_branch_index"] = [0] * self.ncomp
3 changes: 2 additions & 1 deletion jaxley/modules/cell.py
@@ -18,6 +18,7 @@
compute_levels,
compute_morphology_indices_in_levels,
compute_parents_in_level,
dtype_aware_concat,
)
from jaxley.utils.misc_utils import cumsum_leading_zero, deprecated_kwargs
from jaxley.utils.solver_utils import (
@@ -102,7 +103,7 @@ def __init__(
self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])

# Build nodes. Has to be changed when `.set_ncomp()` is run.
self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)
self.nodes = dtype_aware_concat([c.nodes for c in branch_list])
self.nodes["global_comp_index"] = np.arange(self.cumsum_ncomp[-1])
self.nodes["global_branch_index"] = np.repeat(
np.arange(self.total_nbranches), self.ncomp_per_branch
11 changes: 6 additions & 5 deletions jaxley/modules/network.py
@@ -19,6 +19,7 @@
build_branchpoint_group_inds,
compute_children_and_parents,
compute_current_density,
dtype_aware_concat,
loc_of_index,
merge_cells,
query_states_and_params,
@@ -67,7 +68,7 @@ def __init__(
self.total_nbranches = sum(self.nbranches_per_cell)
self._cumsum_nbranches = cumsum_leading_zero(self.nbranches_per_cell)

self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)
self.nodes = dtype_aware_concat([c.nodes for c in cells])
self.nodes["global_comp_index"] = np.arange(self.cumsum_ncomp[-1])
self.nodes["global_branch_index"] = np.repeat(
np.arange(self.total_nbranches), self.ncomp_per_branch
@@ -267,8 +268,8 @@ def _step_synapse_state(

for i, group in edges.groupby("type_ind"):
synapse = syn_channels[i]
pre_inds = group["global_pre_comp_index"].to_numpy()
post_inds = group["global_post_comp_index"].to_numpy()
pre_inds = group["pre_global_comp_index"].to_numpy()
post_inds = group["post_global_comp_index"].to_numpy()
edge_inds = group.index.to_numpy()

query_syn = lambda d, names: query_states_and_params(d, names, edge_inds)
@@ -311,8 +312,8 @@ def _synapse_currents(
synapse_current_states = {f"{s._name}_current": zeros for s in syn_channels}
for i, group in edges.groupby("type_ind"):
synapse = syn_channels[i]
pre_inds = group["global_pre_comp_index"].to_numpy()
post_inds = group["global_post_comp_index"].to_numpy()
pre_inds = group["pre_global_comp_index"].to_numpy()
post_inds = group["post_global_comp_index"].to_numpy()
edge_inds = group.index.to_numpy()

query_syn = lambda d, names: query_states_and_params(d, names, edge_inds)
13 changes: 13 additions & 0 deletions jaxley/utils/cell_utils.py
@@ -774,3 +774,16 @@ def compute_children_and_parents(
child_belongs_to_branchpoint = remap_to_consecutive(par_inds)
par_inds = np.unique(par_inds)
return par_inds, child_inds, child_belongs_to_branchpoint


def dtype_aware_concat(dfs) -> pd.DataFrame:
"""Concatenate DataFrames and restore the column dtypes of the input frames."""
concat_df = pd.concat(dfs, ignore_index=True)
# Replace NaNs with None so the dtype casts below behave consistently for
# columns that are missing from some of the frames.
concat_df[concat_df.isna()] = None
for col in concat_df.columns[concat_df.dtypes == "object"]:
for df in dfs:
if col in df.columns:
concat_df[col] = concat_df[col].astype(df[col].dtype)
break # first match is sufficient
return concat_df
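
A small, self-contained illustration of the problem this helper addresses (the column names are made up, and the exact casting of rows that lack a value depends on the pandas version): with plain `pd.concat`, a bool column that is absent from one of the frames is promoted to `object`, whereas `dtype_aware_concat` restores the dtype from the first frame that defines the column.

import pandas as pd
from jaxley.utils.cell_utils import dtype_aware_concat

df1 = pd.DataFrame({"radius": [1.0, 1.0], "HH": [True, True]})
df2 = pd.DataFrame({"radius": [2.0, 2.0]})  # no "HH" column

plain = pd.concat([df1, df2], ignore_index=True)
aware = dtype_aware_concat([df1, df2])

print(plain["HH"].dtype)  # object -- the bool dtype is lost by plain concat
print(aware["HH"].dtype)  # expected: bool, restored from df1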