diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 50f73ced..62d96f4a 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -725,6 +725,19 @@ def compute_xyz(self): self.xyzr[b][:, :2] = np.asarray([start_point, end_point]) + def move(self, x: float = 0.0, y: float = 0.0, z: float = 0.0): + self._move(x, y, z, self.nodes) + + def _move(self, x: float, y: float, z: float, view): + # Need to cast to set because this will return one columnn per compartment, + # not one column per branch. + indizes = set(view["branch_index"].to_numpy().tolist()) + for i in indizes: + self.xyzr[i][:, 0] += x + self.xyzr[i][:, 1] += y + self.xyzr[i][:, 2] += z + + class View: """View of a `Module`.""" @@ -832,6 +845,10 @@ def vis( morph_plot_kwargs=morph_plot_kwargs, ) + def move(self, x: float = 0.0, y: float = 0.0, z: float = 0.0): + nodes = self.set_global_index_and_index(self.view) + self.pointer._move(x, y, z, nodes) + def adjust_view(self, key: str, index: float): """Update view.""" if isinstance(index, int) or isinstance(index, np.int64): diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 8a391409..66483447 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -39,7 +39,7 @@ def __init__( self._append_to_params_and_state(cells) for cell in cells: self._append_to_channel_params_and_state(cell) - self.xyzr += cell.xyzr + self.xyzr += deepcopy(cell.xyzr) self._append_synapses_to_params_and_state(connectivities) self.cells = cells