From c7a04c3bf7bb4efe89ad2c9a000d9d357b1125b3 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 10 Apr 2024 22:05:40 +0200 Subject: [PATCH] Fix __iter__ and .shape --- jaxley/modules/base.py | 29 +++-------------------------- jaxley/modules/branch.py | 14 ++++++++++++-- jaxley/modules/cell.py | 13 ++++++++++++- jaxley/modules/compartment.py | 11 ++++++++++- jaxley/modules/network.py | 6 ++++++ jaxley/utils/cell_utils.py | 24 ++++++++++++++++++++++++ 6 files changed, 67 insertions(+), 30 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 4d41eb70..e663a8eb 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -22,6 +22,7 @@ flip_comp_indices, interpolate_xyz, loc_of_index, + get_local_indices ) from jaxley.utils.plot_utils import plot_morph @@ -1249,29 +1250,6 @@ def adjust_view(self, key: str, index: Union[int, str, list, range, slice]): self.view["controlled_by_param"] -= self.view["controlled_by_param"].iloc[0] return self - def _get_local_indices(self): - """Computes local from global indices. - - #cell_index, branch_index, comp_index - 0, 0, 0 --> 0, 0, 0 # 1st compartment of 1st branch of 1st cell - 0, 0, 1 --> 0, 0, 1 # 2nd compartment of 1st branch of 1st cell - 0, 1, 2 --> 0, 1, 0 # 1st compartment of 2nd branch of 1st cell - 0, 1, 3 --> 0, 1, 1 # 2nd compartment of 2nd branch of 1st cell - 1, 2, 4 --> 1, 0, 0 # 1st compartment of 1st branch of 2nd cell - 1, 2, 5 --> 1, 0, 1 # 2nd compartment of 1st branch of 2nd cell - 1, 3, 6 --> 1, 1, 0 # 1st compartment of 2nd branch of 2nd cell - 1, 3, 7 --> 1, 1, 1 # 2nd compartment of 2nd branch of 2nd cell - """ - - def reindex_a_by_b(df, a, b): - df.loc[:, a] = df.groupby(b)[a].rank(method="dense").astype(int) - 1 - return df - - idcs = self.view[["cell_index", "branch_index", "comp_index"]] - idcs = reindex_a_by_b(idcs, "branch_index", "cell_index") - idcs = reindex_a_by_b(idcs, "comp_index", ["cell_index", "branch_index"]) - return idcs - def _childview(self, index: Union[int, str, list, range, slice]): """Return the child view of the current view. @@ -1293,7 +1271,7 @@ def __getitem__(self, index): return self._childview(index) def __iter__(self): - for i in range(self.shape[0]): + for i in range(self.shape[1]): yield self[i] def rotate(self, degrees: float, rotation_axis: str = "xy"): @@ -1309,8 +1287,7 @@ def rotate(self, degrees: float, rotation_axis: str = "xy"): @property def shape(self): - local_idcs = self._get_local_indices() - return tuple(local_idcs.nunique()) + raise NotImplementedError def _append_multiple_synapses( self, pre_rows: pd.DataFrame, post_rows: pd.DataFrame, synapse_type: Synapse diff --git a/jaxley/modules/branch.py b/jaxley/modules/branch.py index c006765d..cbcf5b97 100644 --- a/jaxley/modules/branch.py +++ b/jaxley/modules/branch.py @@ -7,7 +7,7 @@ from jaxley.modules.base import GroupView, Module, View from jaxley.modules.compartment import Compartment, CompartmentView -from jaxley.utils.cell_utils import compute_coupling_cond +from jaxley.utils.cell_utils import compute_coupling_cond, get_local_indices class Branch(Module): @@ -82,6 +82,11 @@ def __getattr__(self, key): else: raise KeyError(f"Key {key} not recognized.") + @property + def shape(self): + local_idcs = get_local_indices(self.nodes) + return tuple(local_idcs.nunique())[2:] + def init_conds(self, params): conds = self.init_branch_conds( params["axial_resistivity"], params["radius"], params["length"], self.nseg @@ -134,7 +139,7 @@ def __init__(self, pointer, view): super().__init__(pointer, view) def __call__(self, index: float): - local_idcs = self._get_local_indices() + local_idcs = get_local_indices(self.view) self.view[local_idcs.columns] = ( local_idcs # set indexes locally. enables net[0:2,0:2] ) @@ -146,3 +151,8 @@ def __getattr__(self, key): assert key in ["comp", "loc"] compview = CompartmentView(self.pointer, self.view) return compview if key == "comp" else compview.loc + + @property + def shape(self): + local_idcs = get_local_indices(self.view) + return tuple(local_idcs.nunique())[1:] \ No newline at end of file diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 7f901bd4..d78a4d24 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -16,6 +16,7 @@ compute_coupling_cond, compute_levels, loc_of_index, + get_local_indices ) from jaxley.utils.swc import swc_to_jaxley @@ -115,6 +116,11 @@ def __getattr__(self, key): else: raise KeyError(f"Key {key} not recognized.") + @property + def shape(self): + local_idcs = get_local_indices(self.nodes) + return tuple(local_idcs.nunique())[1:] + def init_morph(self): """Initialize morphology.""" parents = self.comb_parents @@ -236,7 +242,7 @@ def __init__(self, pointer, view): super().__init__(pointer, view) def __call__(self, index: float): - local_idcs = self._get_local_indices() + local_idcs = get_local_indices(self.view) self.view[local_idcs.columns] = ( local_idcs # set indexes locally. enables net[0:2,0:2] ) @@ -249,6 +255,11 @@ def __getattr__(self, key): assert key == "branch" return BranchView(self.pointer, self.view) + @property + def shape(self): + local_idcs = get_local_indices(self.view) + return tuple(local_idcs.nunique()) + def rotate(self, degrees: float, rotation_axis: str = "xy"): """Rotate jaxley modules clockwise. Used only for visualization. diff --git a/jaxley/modules/compartment.py b/jaxley/modules/compartment.py index 6504215b..42b9eaf9 100644 --- a/jaxley/modules/compartment.py +++ b/jaxley/modules/compartment.py @@ -6,7 +6,7 @@ from matplotlib.axes import Axes from jaxley.modules.base import Module, View -from jaxley.utils.cell_utils import index_of_loc, interpolate_xyz, loc_of_index +from jaxley.utils.cell_utils import index_of_loc, interpolate_xyz, loc_of_index, get_local_indices class Compartment(Module): @@ -45,6 +45,10 @@ def __init__(self): # Coordinates. self.xyzr = [float("NaN") * np.zeros((2, 4))] + @property + def shape(self): + return () + def init_conds(self, params): cond_params = { "branch_conds_fwd": jnp.asarray([]), @@ -72,6 +76,11 @@ def __call__(self, index: int): "'CompartmentView' object has no attribute 'comp' or 'loc'." ) + @property + def shape(self): + local_idcs = get_local_indices(self.view) + return tuple(local_idcs.nunique())[2:] + def loc(self, loc: float): if loc != "all": assert ( diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index cc1471e0..c7ed9b96 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -18,6 +18,7 @@ convert_point_process_to_distributed, flip_comp_indices, merge_cells, + get_local_indices, ) from jaxley.utils.syn_utils import gather_synapes, prepare_syn @@ -111,6 +112,11 @@ def __getattr__(self, key): else: raise KeyError(f"Key {key} not recognized.") + @property + def shape(self): + local_idcs = get_local_indices(self.nodes) + return tuple(local_idcs.nunique()) + def init_morph(self): self.nbranches_per_cell = [cell.total_nbranches for cell in self.cells] self.total_nbranches = sum(self.nbranches_per_cell) diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index cb470d39..699899df 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -242,3 +242,27 @@ def convert_point_process_to_distributed( area = 2 * pi * radius * length current /= area # nA / um^2 return current * 100_000 # Convert (nA / um^2) to (uA / cm^2) + + +def get_local_indices(view): + """Computes local from global indices. + + #cell_index, branch_index, comp_index + 0, 0, 0 --> 0, 0, 0 # 1st compartment of 1st branch of 1st cell + 0, 0, 1 --> 0, 0, 1 # 2nd compartment of 1st branch of 1st cell + 0, 1, 2 --> 0, 1, 0 # 1st compartment of 2nd branch of 1st cell + 0, 1, 3 --> 0, 1, 1 # 2nd compartment of 2nd branch of 1st cell + 1, 2, 4 --> 1, 0, 0 # 1st compartment of 1st branch of 2nd cell + 1, 2, 5 --> 1, 0, 1 # 2nd compartment of 1st branch of 2nd cell + 1, 3, 6 --> 1, 1, 0 # 1st compartment of 2nd branch of 2nd cell + 1, 3, 7 --> 1, 1, 1 # 2nd compartment of 2nd branch of 2nd cell + """ + + def reindex_a_by_b(df, a, b): + df.loc[:, a] = df.groupby(b)[a].rank(method="dense").astype(int) - 1 + return df + + idcs = view[["cell_index", "branch_index", "comp_index"]] + idcs = reindex_a_by_b(idcs, "branch_index", "cell_index") + idcs = reindex_a_by_b(idcs, "comp_index", ["cell_index", "branch_index"]) + return idcs \ No newline at end of file