Skip to content

Commit

Permalink
.vis() works for networks
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 8, 2023
1 parent 813ea59 commit 9dbdfd1
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ def _vis(self, ax, col, dims, view, morph_plot_kwargs):
coords = [self.xyzr[branch_ind] for branch_ind in branches_inds]

assert not np.any(
np.isnan(self.xyzr[:, :, dims])
np.isnan(np.asarray(self.xyzr)[:, :, dims])
), "No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`."
ax = plot_swc(
coords,
Expand Down Expand Up @@ -723,7 +723,7 @@ def compute_xyz(self):
]
endpoints.append(end_point)

self.xyzr[b, :, :2] = np.asarray([start_point, end_point])
self.xyzr[b][:, :2] = np.asarray([start_point, end_point])


class View:
Expand Down
2 changes: 1 addition & 1 deletion jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
# Since `xyzr` is only inspected at `.vis()` and because it depends on the
# (potentially learned) length of every compartment, we only populate
# self.xyzr at `.vis()`.
self.xyzr = float("NaN") * np.zeros((len(parents), 2, 4))
self.xyzr = [float("NaN") * np.zeros((2, 4)) for _ in range(len(parents))]

self._append_to_params_and_state(branch_list)
for branch in branch_list:
Expand Down
4 changes: 2 additions & 2 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,8 @@ def vis(
c=synapse_col,
**synapse_scatter_kwargs,
)
else:
raise ValueError("detail must be in {full, point}.")
else:
raise ValueError("detail must be in {full, point}.")

return ax

Expand Down

0 comments on commit 9dbdfd1

Please sign in to comment.