diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index b70db9c6..50f73ced 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -667,7 +667,7 @@ def _vis(self, ax, col, dims, view, morph_plot_kwargs): coords = [self.xyzr[branch_ind] for branch_ind in branches_inds] assert not np.any( - np.isnan(self.xyzr[:, :, dims]) + np.isnan(np.asarray(self.xyzr)[:, :, dims]) ), "No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`." ax = plot_swc( coords, @@ -723,7 +723,7 @@ def compute_xyz(self): ] endpoints.append(end_point) - self.xyzr[b, :, :2] = np.asarray([start_point, end_point]) + self.xyzr[b][:, :2] = np.asarray([start_point, end_point]) class View: diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 2d6aa922..5727d413 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -57,7 +57,7 @@ def __init__( # Since `xyzr` is only inspected at `.vis()` and because it depends on the # (potentially learned) length of every compartment, we only populate # self.xyzr at `.vis()`. - self.xyzr = float("NaN") * np.zeros((len(parents), 2, 4)) + self.xyzr = [float("NaN") * np.zeros((2, 4)) for _ in range(len(parents))] self._append_to_params_and_state(branch_list) for branch in branch_list: diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index d41984f4..8a391409 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -392,8 +392,8 @@ def vis( c=synapse_col, **synapse_scatter_kwargs, ) - else: - raise ValueError("detail must be in {full, point}.") + else: + raise ValueError("detail must be in {full, point}.") return ax