Skip to content

Commit

Permalink
fix: all tests passing, address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Oct 29, 2024
1 parent 049cdd8 commit 192a5d1
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 19 deletions.
56 changes: 41 additions & 15 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ def __getattr__(self, key):
else self.select(None)
)
view._set_controlled_by_param(key) # overwrites param set by edge
# Temporary fix for synapse param sharing
# Ensure synapse param sharing works with `edge`
# `edge` will be removed as part of #463
view.edges["local_edge_index"] = np.arange(len(view.edges))
return view

Expand All @@ -229,18 +230,16 @@ def _childviews(self) -> List[str]:
return children

def __getitem__(self, index):
supported_lvls = ["network", "cell", "branch"] # cannot index into comp
"""Lazy indexing of the module."""
supported_parents = ["network", "cell", "branch"] # cannot index into comp

# TODO FROM #447: SHOULD WE ALLOW GROUPVIEW TO BE INDEXED?
# IF YES, UNDER WHICH CONDITIONS?
is_group_view = self._current_view in self.groups
not_group_view = self._current_view not in self.groups
assert (
self._current_view in supported_lvls or is_group_view
), "Lazy indexing is not supported for this View/Module."
self._current_view in supported_parents or not_group_view
), "Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof."
index = index if isinstance(index, tuple) else (index,)

module_or_view = self.base if is_group_view else self
child_views = module_or_view._childviews()
child_views = self._childviews()
assert len(index) <= len(child_views), "Too many indices."
view = self
for i, child in zip(index, child_views):
Expand Down Expand Up @@ -307,8 +306,8 @@ def _init_view(self):
"""Init attributes critical for View.
Needs to be called at init of a Module."""
lvl = self.__class__.__name__.lower()
self._current_view = "comp" if lvl == "compartment" else lvl
parent = self.__class__.__name__.lower()
self._current_view = "comp" if parent == "compartment" else parent
self._nodes_in_view = self.nodes.index.to_numpy()
self._edges_in_view = self.edges.index.to_numpy()
self.nodes["controlled_by_param"] = 0
Expand Down Expand Up @@ -2205,6 +2204,13 @@ class View(Module):
allow to target specific parts of a Module, i.e. setting parameters for parts
of a cell.
Almost all methods in View are concerned with updating the attributes of the
base Module, i.e. `self.base`, based on the indices in view. For example,
`_channels_in_view` lists all channels, finds the subset set to `True` in
`self.nodes` (currently in view) and returns the updated list such that we can set
`self.channels = self._channels_in_view()`.
To allow seamless operation on Views and Modules as if they were the same,
the following needs to be ensured:
1. We consider a Module to have everything in view.
Expand Down Expand Up @@ -2316,7 +2322,7 @@ def __init__(
def _set_inds_in_view(
self, pointer: Union[Module, View], nodes: np.ndarray, edges: np.ndarray
):
"""Set nodes and edge indices that are in view."""
"""Update node and edge indices to list only those currently in view."""
# set nodes and edge indices in view
has_node_inds = nodes is not None
has_edge_inds = edges is not None
Expand Down Expand Up @@ -2352,6 +2358,7 @@ def _set_inds_in_view(
self._edges_in_view = edges

def _jax_arrays_in_view(self, pointer: Union[Module, View]):
"""Update jaxnodes/jaxedges to show only those currently in view."""
a_intersects_b_at = lambda a, b: jnp.intersect1d(a, b, return_indices=True)[1]
jaxnodes = {} if pointer.jaxnodes is not None else None
if self.jaxnodes is not None:
Expand All @@ -2376,6 +2383,7 @@ def _jax_arrays_in_view(self, pointer: Union[Module, View]):
return jaxnodes, jaxedges

def _set_externals_in_view(self):
"""Update external inputs to show only those currently in view."""
self.externals = {}
self.external_inds = {}
for (name, inds), data in zip(
Expand All @@ -2390,7 +2398,17 @@ def _set_externals_in_view(self):
def _filter_trainables(
self, is_viewed: bool = True
) -> Tuple[List[np.ndarray], List[Dict]]:
"""filters the trainables inside and outside of the view
"""Filters the trainables inside and outside of the view.
Trainables are split between `indices_set_by_trainables` and `trainable_params`
and can be shared between mutliple compartments / branches etc, which makes it
difficult to filter them based on the current view w.o. destroying the
original structure.
This method filters `indices_set_by_trainables` for the indices that are
currently in view (or not in view) and returns the corresponding trainable
parameters and indices such that the sharing behavior is preserved as much as
possible.
Args:
is_viewed: Toggles between returning the trainables and inds
Expand Down Expand Up @@ -2425,8 +2443,9 @@ def _filter_trainables(
índices_set_by_trainables_in_view.append(inds[completely_in_view])
partial_inds = inds[partially_in_view][in_view[partially_in_view]]

# the indexing above can lead to inconsistent shapes.
# this is fixed here to return them to the prev shape
# the indexing i.e. `inds[partially_in_view]` reshapes `inds`. Since the shape
# determines how parameters are shared, `inds` has to be returned to its
# original shape.
if inds.shape[0] > 1 and partial_inds.shape != (0,):
partial_inds = partial_inds.reshape(-1, 1)
if inds.shape[1] > 1 and partial_inds.shape != (0,):
Expand All @@ -2443,6 +2462,7 @@ def _filter_trainables(
return indices_set_by_trainables, trainable_params

def _set_trainables_in_view(self):
"""Set `trainable_params` and `indices_set_by_trainables` to show only those in view."""
trainables = self._filter_trainables()

# note for `branch.comp(0).make_trainable("X"); branch.make_trainable("X")`
Expand All @@ -2451,12 +2471,14 @@ def _set_trainables_in_view(self):
self.trainable_params = trainables[1]

def _channels_in_view(self, pointer: Union[Module, View]) -> List[Channel]:
"""Set channels to show only those in view."""
names = [name._name for name in pointer.channels]
channel_in_view = self.nodes[names].any(axis=0)
channel_in_view = channel_in_view[channel_in_view].index
return [c for c in pointer.channels if c._name in channel_in_view]

def _set_synapses_in_view(self, pointer: Union[Module, View]):
"""Set synapses to show only those in view."""
viewed_synapses = []
viewed_params = []
viewed_states = []
Expand All @@ -2478,6 +2500,9 @@ def _nbranches_per_cell_in_view(self) -> np.ndarray:
return cell_nodes["global_branch_index"].nunique().to_list()

def _xyzr_in_view(self) -> List[np.ndarray]:
"""Return xyzr coordinates of every branch that is in `_branches_in_view`.
If a branch is not completely in view, the coordinates are interpolated."""
xyzr = [self.base.xyzr[i] for i in self._branches_in_view].copy()

# Currently viewing with `.loc` will show the closest compartment
Expand Down Expand Up @@ -2527,6 +2552,7 @@ def _comps_in_view(self) -> np.ndarray:

@property
def _branch_edges_in_view(self) -> np.ndarray:
"""Lists the global branch edge indices which are currently part of the view."""
incl_branches = self.nodes["global_branch_index"].unique()
pre = self.base.branch_edges["parent_branch_index"].isin(incl_branches)
post = self.base.branch_edges["child_branch_index"].isin(incl_branches)
Expand Down
26 changes: 22 additions & 4 deletions tests/test_viewing.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,18 @@ def test_solve_indexer():
# make sure all attrs in module also have a corresponding attr in view
@pytest.mark.parametrize("module", [comp, branch, cell, net])
def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network):
"""Check if all attributes of Module have a corresponding attribute in View.
To ensure that View behaves like a Module as much as possible, View should support
all attributes of Module. This test checks if all attributes of Module have a
corresponding attribute in View. Also checks if the types of the attributes match.
"""
# attributes of Module that do not have to exist in View
exceptions = ["view"]

# TODO: Types are inconsistent between different Modules
exceptions += ["cumsum_nbranches"]

# TODO FROM #447: should be added to View in the future
exceptions += [
"_internal_node_inds",
Expand All @@ -278,7 +288,6 @@ def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network):
"cumsum_nbranchpoints_per_cell",
"_cumsum_nseg_per_cell",
] # for network
exceptions += ["cumsum_nbranches"] # TODO: take care of this

for name, attr in module.__dict__.items():
if name not in exceptions:
Expand All @@ -298,7 +307,8 @@ def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network):


@pytest.mark.parametrize("module", [comp, branch, cell, net])
def test_different_index_types(module):
def test_view_supported_index_types(module):
"""Check if different ways to index into Modules/Views work correctly."""
# test int, range, slice, list, np.array, pd.Index
index_types = [
0,
Expand All @@ -308,6 +318,7 @@ def test_different_index_types(module):
np.array([0, 1, 2]),
pd.Index([0, 1, 2]),
]
# `_reformat_index` should always return a np.ndarray
for index in index_types:
assert isinstance(
module._reformat_index(index), np.ndarray
Expand All @@ -321,6 +332,7 @@ def test_different_index_types(module):


def test_select():
"""Ensure `select` works correctly and returns expected View of Modules."""
comp = jx.Compartment()
branch = jx.Branch([comp] * 3)
cell = jx.Cell([branch] * 3, parents=[-1, 0, 0])
Expand Down Expand Up @@ -362,6 +374,7 @@ def test_select():


def test_viewing():
"""Test that the View object is working correctly."""
comp = jx.Compartment()
branch = jx.Branch([comp] * 3)
cell = jx.Cell([branch] * 3, parents=[-1, 0, 0])
Expand Down Expand Up @@ -415,6 +428,7 @@ def test_viewing():


def test_scope():
"""Ensure scope has the intended effect for Modules and Views."""
comp = jx.Compartment()
branch = jx.Branch([comp] * 3)
cell = jx.Cell([branch] * 3, parents=[-1, 0, 0])
Expand Down Expand Up @@ -448,6 +462,7 @@ def test_scope():


def test_context_manager():
"""Test that context manager works correctly for Module."""
comp = jx.Compartment()
branch = jx.Branch([comp] * 3)
cell = jx.Cell([branch] * 3, parents=[-1, 0, 0])
Expand All @@ -472,6 +487,7 @@ def test_context_manager():


def test_iter():
"""Test that __iter__ works correctly for all modules."""
comp = jx.Compartment()
branch1 = jx.Branch([comp] * 2)
branch2 = jx.Branch([comp] * 3)
Expand Down Expand Up @@ -531,6 +547,7 @@ def test_iter():


def test_synapse_and_channel_filtering():
"""Test that synapses and channels are filtered correctly by View."""
comp = jx.Compartment()
branch = jx.Branch([comp] * 3)
cell = jx.Cell([branch] * 3, parents=[-1, 0, 0])
Expand All @@ -550,15 +567,16 @@ def test_synapse_and_channel_filtering():
edges_control_param1 = edges1.pop("controlled_by_param")
edges_control_param2 = edges2.pop("controlled_by_param")

assert np.all(nodes1 == nodes2)
# convert to dict so order of cols and index dont matter for __eq__
assert nodes1.to_dict() == nodes2.to_dict()
assert np.all(nodes_control_param1 == 0)
assert np.all(nodes_control_param2 == nodes2["global_cell_index"])

assert np.all(edges1 == edges2)


def test_view_equals_module():
# test that module behaves the same as view for important attributes
"""Test that View behaves the same as Module for important attrs and methods."""
comp = jx.Compartment()
branch = jx.Branch([comp] * 3)

Expand Down

0 comments on commit 192a5d1

Please sign in to comment.