Skip to content

Commit

Permalink
Bugfix for GroupView
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 19, 2023
1 parent c1f29b7 commit 9d9bafb
Show file tree
Hide file tree
Showing 8 changed files with 651 additions and 320 deletions.
17 changes: 9 additions & 8 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self):
self.total_nbranches: int = 0
self.nbranches_per_cell: List[int] = None

self.group_views = {}
self.group_nodes = {}

self.nodes: Optional[pd.DataFrame] = None

Expand Down Expand Up @@ -125,10 +125,10 @@ def to_jax(self):
edges = self.edges.to_dict(orient="list")
for i, synapse in enumerate(self.synapses):
for key in synapse.synapse_params:
condition = jnp.asarray(edges["type_ind"]) == i
self.jaxedges[key] = jnp.asarray(edges[key])[condition]
condition = np.asarray(edges["type_ind"]) == i
self.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])
for key in synapse.synapse_states:
self.jaxedges[key] = jnp.asarray(edges[key])[condition]
self.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])

def show(
self,
Expand Down Expand Up @@ -302,9 +302,9 @@ def add_to_group(self, group_name):
raise ValueError("`add_to_group()` makes no sense for an entire module.")

def _add_to_group(self, group_name, view):
if group_name in self.group_views:
view = pd.concat([self.group_views[group_name].view, view])
self.group_views[group_name] = GroupView(self, view)
if group_name in self.group_nodes.keys():
view = pd.concat([self.group_nodes[group_name], view])
self.group_nodes[group_name] = view

def get_parameters(self):
"""Get all trainable parameters."""
Expand Down Expand Up @@ -712,7 +712,8 @@ def show(
"global_cell_index",
"controlled_by_param",
]:
view = view.drop(name, axis=1)
if name in view.columns:
view = view.drop(name, axis=1)
return view

def set_global_index_and_index(self, nodes):
Expand Down
5 changes: 3 additions & 2 deletions jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ def __getattr__(self, key):
view["global_branch_index"] = view["branch_index"]
view["global_cell_index"] = view["cell_index"]
return CompartmentView(self, view)
elif key in self.group_views:
return self.group_views[key]
elif key in self.group_nodes:
inds = self.group_nodes[key].index.values
return GroupView(self, self.nodes.loc[inds])
else:
raise KeyError(f"Key {key} not recognized.")

Expand Down
7 changes: 4 additions & 3 deletions jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax import vmap
from jax.lax import ScatterDimensionNumbers, scatter_add

from jaxley.modules.base import Module, View
from jaxley.modules.base import GroupView, Module, View
from jaxley.modules.branch import Branch, BranchView, Compartment
from jaxley.utils.cell_utils import (
compute_branches_in_level,
Expand Down Expand Up @@ -102,8 +102,9 @@ def __getattr__(self, key):
view["global_branch_index"] = view["branch_index"]
view["global_cell_index"] = view["cell_index"]
return BranchView(self, view)
elif key in self.group_views:
return self.group_views[key]
elif key in self.group_nodes:
inds = self.group_nodes[key].index.values
return GroupView(self, self.nodes.loc[inds])
else:
raise KeyError(f"Key {key} not recognized.")

Expand Down
7 changes: 4 additions & 3 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jax import vmap

from jaxley.connection import Connectivity
from jaxley.modules.base import Module, View
from jaxley.modules.base import GroupView, Module, View
from jaxley.modules.branch import Branch
from jaxley.modules.cell import Cell, CellView
from jaxley.utils.cell_utils import merge_cells
Expand Down Expand Up @@ -96,8 +96,9 @@ def __getattr__(self, key):
elif key in self.synapse_names:
type_index = self.synapse_names.index(key)
return SynapseView(self, self.edges, key, self.synapses[type_index])
elif key in self.group_views:
return self.group_views[key]
elif key in self.group_nodes:
inds = self.group_nodes[key].index.values
return GroupView(self, self.nodes.loc[inds])
else:
raise KeyError(f"Key {key} not recognized.")

Expand Down
59 changes: 19 additions & 40 deletions tutorials/01_small_network.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 9d9bafb

Please sign in to comment.