Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cell visualization #182

Merged
merged 6 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 69 additions & 3 deletions jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
compute_coupling_cond,
compute_levels,
)
from jaxley.utils.plot_utils import plot_morph, plot_swc
from jaxley.utils.swc import swc_to_jaxley


Expand All @@ -23,7 +24,21 @@ class Cell(Module):
cell_params: Dict = {}
cell_states: Dict = {}

def __init__(self, branches: Union[Branch, List[Branch]], parents: List):
def __init__(
self,
branches: Union[Branch, List[Branch]],
parents: List,
xyzr: Optional[List[np.ndarray]] = None,
):
"""Initialize a cell.

Args:
branches:
parents:
xyzr: For every branch, the x, y, and z coordinates and the radius at the
traced coordinates. Note that this is the full tracing (from SWC), not
the stick representation coordinates.
"""
super().__init__()
assert isinstance(branches, Branch) or len(parents) == len(
branches
Expand All @@ -33,6 +48,7 @@ def __init__(self, branches: Union[Branch, List[Branch]], parents: List):
branch_list = [branches for _ in range(len(parents))]
else:
branch_list = branches
self.xyzr = xyzr

self._append_to_params_and_state(branch_list)
for branch in branch_list:
Expand Down Expand Up @@ -214,6 +230,54 @@ def update_summed_coupling_conds(
)
return summed_conds

def vis(
self,
detail: str = "full",
figsize=(4, 4),
dims=(0, 1),
cols="k",
highlight_branch_inds=[],
max_y_multiplier: float = 5.0,
min_y_multiplier: float = 0.5,
) -> None:
"""Visualize the network.

Args:
detail: Either of [sticks, full]. `sticks` visualizes all branches of every
neuron, but draws branches as straight lines. `full` plots the full
morphology of every neuron, as read from the SWC file.
layers: Allows to plot the network in layers. Should provide the number of
neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input
neurons, 10 hidden layer neurons, and 1 output neuron.
options: Plotting options passed to `NetworkX.draw()`.
dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of
two of them.
cols: The color for all branches except the highlighted ones.
highlight_branch_inds: Branch indices that will be highlighted.
"""
if detail == "sticks":
fig, ax = plot_morph(
cell=self,
figsize=figsize,
cols=cols,
highlight_branch_inds=highlight_branch_inds,
max_y_multiplier=max_y_multiplier,
min_y_multiplier=min_y_multiplier,
)
elif detail == "full":
assert self.xyzr is not None, "no coordinates, use `vis(detail='sticks')`."
fig, ax = plot_swc(
self.xyzr,
figsize=figsize,
dims=dims,
cols=cols,
highlight_branch_inds=highlight_branch_inds,
)
else:
raise ValueError("`detail must be in {sticks, full}.")

return fig, ax


class CellView(View):
"""CellView."""
Expand All @@ -239,7 +303,7 @@ def read_swc(
min_radius: Optional[float] = None,
):
"""Reads SWC file into a `jx.Cell`."""
parents, pathlengths, radius_fns, _ = swc_to_jaxley(
parents, pathlengths, radius_fns, _, coords_of_branches = swc_to_jaxley(
fname, max_branch_len=max_branch_len, sort=True, num_lines=None
)
nbranches = len(parents)
Expand All @@ -249,7 +313,9 @@ def read_swc(

comp = Compartment().initialize()
branch = Branch([comp for _ in range(nseg)]).initialize()
cell = Cell([branch for _ in range(nbranches)], parents=parents)
cell = Cell(
[branch for _ in range(nbranches)], parents=parents, xyzr=coords_of_branches
)

radiuses = np.flip(
np.asarray([radius_fns[b](range_) for b in range(len(parents))]), axis=1
Expand Down
21 changes: 5 additions & 16 deletions jaxley/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
_compute_num_children,
compute_levels,
)
from jaxley.utils.swc import _build_parents, _split_into_branches_and_sort

highlight_cols = [
"#1f78b4",
Expand Down Expand Up @@ -114,8 +113,7 @@ def plot_morph(


def plot_swc(
fname,
max_branch_len: float = 100.0,
xyzr,
figsize=(4, 4),
dims=(0, 1),
cols="k",
Expand All @@ -129,30 +127,21 @@ def plot_swc(
cols: The color for all branches except the highlighted ones.
highlight_branch_inds: Branch indices that will be highlighted.
"""
content = np.loadtxt(fname)
sorted_branches, _ = _split_into_branches_and_sort(
content, max_branch_len=max_branch_len, sort=True
)
parents = _build_parents(sorted_branches)
if np.sum(np.asarray(parents) == -1) > 1.0:
sorted_branches = [[0]] + sorted_branches
cols = [cols] * len(sorted_branches)
cols = [cols] * len(xyzr)

counter_highlight_branches = 0
lines = []

fig, ax = plt.subplots(1, 1, figsize=figsize)
for i, branch in enumerate(sorted_branches):
coords_of_branch = content[np.asarray(branch) - 1, 2:5]
coords_of_branch = coords_of_branch[:, dims]

for i, coords_of_branch in enumerate(xyzr):
coords_to_plot = coords_of_branch[:, dims]
col = cols[i]
if i in highlight_branch_inds:
col = highlight_cols[counter_highlight_branches % len(highlight_cols)]
counter_highlight_branches += 1

(line,) = ax.plot(
coords_of_branch[:, 0], coords_of_branch[:, 1], c=col, label=f"ind {i}"
coords_to_plot[:, 0], coords_to_plot[:, 1], c=col, label=f"ind {i}"
)
if i in highlight_branch_inds:
lines.append(line)
Expand Down
8 changes: 7 additions & 1 deletion jaxley/utils/swc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,14 @@ def swc_to_jaxley(
parents = parents.tolist()
pathlengths = [0.1] + pathlengths
radius_fns = [lambda x: content[0, 5] * np.ones_like(x)] + radius_fns
sorted_branches = [[0]] + sorted_branches

return parents, pathlengths, radius_fns, types
all_coords_of_branches = []
for i, branch in enumerate(sorted_branches):
coords_of_branch = content[np.asarray(branch) - 1, 2:5]
all_coords_of_branches.append(coords_of_branch)

return parents, pathlengths, radius_fns, types, all_coords_of_branches


def _split_into_branches_and_sort(
Expand Down