Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visualize cells that were not read from SWC #192

Merged
merged 8 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 84 additions & 30 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from jaxley.channels import Channel
from jaxley.solver_voltage import step_voltage_explicit, step_voltage_implicit
from jaxley.synapses import Synapse
from jaxley.utils.plot_utils import plot_morph, plot_swc
from jaxley.utils.cell_utils import (
_compute_index_of_child,
_compute_num_children,
compute_levels,
)
from jaxley.utils.plot_utils import plot_morph


class Module(ABC):
Expand Down Expand Up @@ -636,7 +641,6 @@ def get_external_input(

def vis(
self,
detail: str = "full",
ax=None,
col: str = "k",
dims: Tuple[int] = (0, 1),
Expand All @@ -645,49 +649,97 @@ def vis(
"""Visualize the module.

Args:
detail: Either of [sticks, full]. `sticks` visualizes all branches of every
neuron, but draws branches as straight lines. `full` plots the full
morphology of every neuron, as read from the SWC file.
ax: An axis into which to plot.
col: The color for all branches.
dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of
two of them.
"""
return self._vis(
detail=detail,
dims=dims,
col=col,
ax=ax,
view=self.nodes,
morph_plot_kwargs=morph_plot_kwargs,
)

def _vis(self, detail, ax, col, dims, view, morph_plot_kwargs):
def _vis(self, ax, col, dims, view, morph_plot_kwargs):
branches_inds = view["branch_index"].to_numpy()
coords = [self.xyzr[branch_ind] for branch_ind in branches_inds]

if detail == "full":
assert self.xyzr, "no coordinates available, use `vis(detail='point')`."
ax = plot_swc(
coords,
dims=dims,
col=col,
ax=ax,
morph_plot_kwargs=morph_plot_kwargs,
)
# elif detail == "sticks":
# fig, ax = plot_morph(
# cell=self,
# col=col,
# max_y_multiplier=5.0,
# min_y_multiplier=0.5,
# ax=ax,
# )
else:
raise ValueError("`detail must be in {point, full}.")
coords = []
for branch_ind in branches_inds:
assert not np.any(
np.isnan(self.xyzr[branch_ind][:, dims])
), "No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`."
coords.append(self.xyzr[branch_ind])

ax = plot_morph(
coords,
dims=dims,
col=col,
ax=ax,
morph_plot_kwargs=morph_plot_kwargs,
)

return ax

def compute_xyz(self):
"""Return xyz coordinates of every branch, based on the branch length."""
max_y_multiplier = 5.0
min_y_multiplier = 0.5

parents = self.comb_parents
num_children = _compute_num_children(parents)
index_of_child = _compute_index_of_child(parents)
levels = compute_levels(parents)

# Extract branch.
inds_branch = self.nodes.groupby("branch_index")["comp_index"].apply(list)
branch_lens = [
np.sum(self.params["length"][np.asarray(i)]) for i in inds_branch
]
endpoints = []

# Different levels will get a different "angle" at which the children emerge from
# the parents. This angle is defined by the `y_offset_multiplier`. This value
# defines the range between y-location of the first and of the last child of a
# parent.
y_offset_multiplier = np.linspace(
max_y_multiplier, min_y_multiplier, np.max(levels) + 1
)

for b in range(self.total_nbranches):
if parents[b] > -1:
start_point = endpoints[parents[b]]
num_children_of_parent = num_children[parents[b]]
y_offset = (
((index_of_child[b] / (num_children_of_parent - 1))) - 0.5
) * y_offset_multiplier[levels[b]]
else:
start_point = [0, 0]
y_offset = 0.0

len_of_path = np.sqrt(y_offset**2 + 1.0)

end_point = [
start_point[0] + branch_lens[b] / len_of_path * 1.0,
start_point[1] + branch_lens[b] / len_of_path * y_offset,
]
endpoints.append(end_point)

self.xyzr[b][:, :2] = np.asarray([start_point, end_point])

def move(self, x: float = 0.0, y: float = 0.0, z: float = 0.0):
"""Move cells or networks in the (x, y, z) plane."""
self._move(x, y, z, self.nodes)

def _move(self, x: float, y: float, z: float, view):
# Need to cast to set because this will return one columnn per compartment,
# not one column per branch.
indizes = set(view["branch_index"].to_numpy().tolist())
for i in indizes:
self.xyzr[i][:, 0] += x
self.xyzr[i][:, 1] += y
self.xyzr[i][:, 2] += z


class View:
"""View of a `Module`."""
Expand Down Expand Up @@ -781,22 +833,24 @@ def add_to_group(self, group_name: str):

def vis(
self,
detail: str = "full",
ax=None,
col="k",
dims=(0, 1),
morph_plot_kwargs: Dict = {},
):
nodes = self.set_global_index_and_index(self.view)
return self.pointer._vis(
detail=detail,
ax=ax,
col=col,
dims=dims,
view=nodes,
morph_plot_kwargs=morph_plot_kwargs,
)

def move(self, x: float = 0.0, y: float = 0.0, z: float = 0.0):
nodes = self.set_global_index_and_index(self.view)
self.pointer._move(x, y, z, nodes)

def adjust_view(self, key: str, index: float):
"""Update view."""
if isinstance(index, int) or isinstance(index, np.int64):
Expand Down
12 changes: 11 additions & 1 deletion jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,17 @@ def __init__(
branch_list = [branches for _ in range(len(parents))]
else:
branch_list = branches
self.xyzr = xyzr

if xyzr is not None:
assert len(xyzr) == len(parents)
self.xyzr = xyzr
else:
# For every branch (`len(parents)`), we have a start and end point (`2`) and
# a (x,y,z,r) coordinate for each of them (`4`).
# Since `xyzr` is only inspected at `.vis()` and because it depends on the
# (potentially learned) length of every compartment, we only populate
# self.xyzr at `.vis()`.
self.xyzr = [float("NaN") * np.zeros((2, 4)) for _ in range(len(parents))]

self._append_to_params_and_state(branch_list)
for branch in branch_list:
Expand Down
63 changes: 48 additions & 15 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
self._append_to_params_and_state(cells)
for cell in cells:
self._append_to_channel_params_and_state(cell)
self.xyzr += cell.xyzr
self.xyzr += deepcopy(cell.xyzr)
self._append_synapses_to_params_and_state(connectivities)

self.cells = cells
Expand Down Expand Up @@ -203,6 +203,8 @@ def init_conds(self, params):
def init_syns(self):
global_pre_comp_inds = []
global_post_comp_inds = []
global_pre_branch_inds = []
global_post_branch_inds = []
pre_locs = []
post_locs = []
pre_branch_inds = []
Expand All @@ -220,6 +222,18 @@ def init_syns(self):
global_post_comp_inds.append(
self.cumsum_nbranches[post_cell_inds_] * self.nseg + post_inds
)
global_pre_branch_inds.append(
[
self.cumsum_nbranches[c.pre_cell_ind] + c.pre_branch_ind
for c in connectivity.conns
]
)
global_post_branch_inds.append(
[
self.cumsum_nbranches[c.post_cell_ind] + c.post_branch_ind
for c in connectivity.conns
]
)
# Local compartment inds.
pre_locs.append(np.asarray([c.pre_loc for c in connectivity.conns]))
post_locs.append(np.asarray([c.post_loc for c in connectivity.conns]))
Expand All @@ -246,6 +260,8 @@ def init_syns(self):
"type_ind",
"global_pre_comp_index",
"global_post_comp_index",
"global_pre_branch_index",
"global_post_branch_index",
]
)
for i, connectivity in enumerate(self.connectivities):
Expand All @@ -264,6 +280,8 @@ def init_syns(self):
type_ind=i,
global_pre_comp_index=global_pre_comp_inds[i],
global_post_comp_index=global_post_comp_inds[i],
global_pre_branch_index=global_pre_branch_inds[i],
global_post_branch_index=global_post_branch_inds[i],
)
),
],
Expand Down Expand Up @@ -351,9 +369,8 @@ def vis(
nx.draw(graph, pos, with_labels=True)
else:
nx.draw(graph, with_labels=True)
else:
elif detail == "full":
ax = self._vis(
detail=detail,
dims=dims,
col=col,
ax=ax,
Expand All @@ -363,22 +380,36 @@ def vis(

pre_locs = self.syn_edges["pre_locs"].to_numpy()
post_locs = self.syn_edges["post_locs"].to_numpy()
pre_branch = self.syn_edges["pre_branch_index"].to_numpy()
post_branch = self.syn_edges["post_branch_index"].to_numpy()
pre_cell = self.syn_edges["pre_cell_index"].to_numpy()
post_cell = self.syn_edges["post_cell_index"].to_numpy()
pre_branch = self.syn_edges["global_pre_branch_index"].to_numpy()
post_branch = self.syn_edges["global_post_branch_index"].to_numpy()

dims_np = np.asarray(dims)

for pre_loc, post_loc, pre_b, post_b, pre_c, post_c in zip(
pre_locs, post_locs, pre_branch, post_branch, pre_cell, post_cell
for pre_loc, post_loc, pre_b, post_b in zip(
pre_locs, post_locs, pre_branch, post_branch
):
pre_coord = self.cells[pre_c].xyzr[pre_b]
middle_ind = int((len(pre_coord) - 1) * pre_loc)
pre_coord = pre_coord[middle_ind]
post_coord = self.cells[post_c].xyzr[post_b]
middle_ind = int((len(post_coord) - 1) * post_loc)
post_coord = post_coord[middle_ind]
pre_coord = self.xyzr[pre_b]
if len(pre_coord) == 2:
# If only start and end point of a branch are traced, perform a
# linear interpolation to get the synpase location.
pre_coord = pre_coord[0] + (pre_coord[1] - pre_coord[0]) * pre_loc
else:
# If densely traced, use intermediate trace values for synapse loc.
middle_ind = int((len(pre_coord) - 1) * pre_loc)
pre_coord = pre_coord[middle_ind]

post_coord = self.xyzr[post_b]
if len(post_coord) == 2:
# If only start and end point of a branch are traced, perform a
# linear interpolation to get the synpase location.
post_coord = (
post_coord[0] + (post_coord[1] - post_coord[0]) * post_loc
)
else:
# If densely traced, use intermediate trace values for synapse loc.
middle_ind = int((len(post_coord) - 1) * post_loc)
post_coord = post_coord[middle_ind]

coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T
ax.plot(
coords[0],
Expand All @@ -393,6 +424,8 @@ def vis(
c=synapse_col,
**synapse_scatter_kwargs,
)
else:
raise ValueError("detail must be in {full, point}.")

return ax

Expand Down
1 change: 0 additions & 1 deletion jaxley/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from jaxley.utils.plot_utils import plot_morph, plot_swc
Loading