diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index c8832760..e0a3787c 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -14,6 +14,7 @@ compute_coupling_cond, compute_levels, ) +from jaxley.utils.plot_utils import plot_morph, plot_swc from jaxley.utils.swc import swc_to_jaxley @@ -23,7 +24,21 @@ class Cell(Module): cell_params: Dict = {} cell_states: Dict = {} - def __init__(self, branches: Union[Branch, List[Branch]], parents: List): + def __init__( + self, + branches: Union[Branch, List[Branch]], + parents: List, + xyzr: Optional[List[np.ndarray]] = None, + ): + """Initialize a cell. + + Args: + branches: + parents: + xyzr: For every branch, the x, y, and z coordinates and the radius at the + traced coordinates. Note that this is the full tracing (from SWC), not + the stick representation coordinates. + """ super().__init__() assert isinstance(branches, Branch) or len(parents) == len( branches @@ -33,6 +48,7 @@ def __init__(self, branches: Union[Branch, List[Branch]], parents: List): branch_list = [branches for _ in range(len(parents))] else: branch_list = branches + self.xyzr = xyzr self._append_to_params_and_state(branch_list) for branch in branch_list: @@ -214,6 +230,54 @@ def update_summed_coupling_conds( ) return summed_conds + def vis( + self, + detail: str = "full", + figsize=(4, 4), + dims=(0, 1), + cols="k", + highlight_branch_inds=[], + max_y_multiplier: float = 5.0, + min_y_multiplier: float = 0.5, + ) -> None: + """Visualize the network. + + Args: + detail: Either of [sticks, full]. `sticks` visualizes all branches of every + neuron, but draws branches as straight lines. `full` plots the full + morphology of every neuron, as read from the SWC file. + layers: Allows to plot the network in layers. Should provide the number of + neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input + neurons, 10 hidden layer neurons, and 1 output neuron. + options: Plotting options passed to `NetworkX.draw()`. + dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of + two of them. + cols: The color for all branches except the highlighted ones. + highlight_branch_inds: Branch indices that will be highlighted. + """ + if detail == "sticks": + fig, ax = plot_morph( + cell=self, + figsize=figsize, + cols=cols, + highlight_branch_inds=highlight_branch_inds, + max_y_multiplier=max_y_multiplier, + min_y_multiplier=min_y_multiplier, + ) + elif detail == "full": + assert self.xyzr is not None, "no coordinates, use `vis(detail='sticks')`." + fig, ax = plot_swc( + self.xyzr, + figsize=figsize, + dims=dims, + cols=cols, + highlight_branch_inds=highlight_branch_inds, + ) + else: + raise ValueError("`detail must be in {sticks, full}.") + + return fig, ax + class CellView(View): """CellView.""" @@ -239,7 +303,7 @@ def read_swc( min_radius: Optional[float] = None, ): """Reads SWC file into a `jx.Cell`.""" - parents, pathlengths, radius_fns, _ = swc_to_jaxley( + parents, pathlengths, radius_fns, _, coords_of_branches = swc_to_jaxley( fname, max_branch_len=max_branch_len, sort=True, num_lines=None ) nbranches = len(parents) @@ -249,7 +313,9 @@ def read_swc( comp = Compartment().initialize() branch = Branch([comp for _ in range(nseg)]).initialize() - cell = Cell([branch for _ in range(nbranches)], parents=parents) + cell = Cell( + [branch for _ in range(nbranches)], parents=parents, xyzr=coords_of_branches + ) radiuses = np.flip( np.asarray([radius_fns[b](range_) for b in range(len(parents))]), axis=1 diff --git a/jaxley/utils/plot_utils.py b/jaxley/utils/plot_utils.py index cf2a430e..d63e35df 100644 --- a/jaxley/utils/plot_utils.py +++ b/jaxley/utils/plot_utils.py @@ -6,7 +6,6 @@ _compute_num_children, compute_levels, ) -from jaxley.utils.swc import _build_parents, _split_into_branches_and_sort highlight_cols = [ "#1f78b4", @@ -114,8 +113,7 @@ def plot_morph( def plot_swc( - fname, - max_branch_len: float = 100.0, + xyzr, figsize=(4, 4), dims=(0, 1), cols="k", @@ -129,30 +127,21 @@ def plot_swc( cols: The color for all branches except the highlighted ones. highlight_branch_inds: Branch indices that will be highlighted. """ - content = np.loadtxt(fname) - sorted_branches, _ = _split_into_branches_and_sort( - content, max_branch_len=max_branch_len, sort=True - ) - parents = _build_parents(sorted_branches) - if np.sum(np.asarray(parents) == -1) > 1.0: - sorted_branches = [[0]] + sorted_branches - cols = [cols] * len(sorted_branches) + cols = [cols] * len(xyzr) counter_highlight_branches = 0 lines = [] fig, ax = plt.subplots(1, 1, figsize=figsize) - for i, branch in enumerate(sorted_branches): - coords_of_branch = content[np.asarray(branch) - 1, 2:5] - coords_of_branch = coords_of_branch[:, dims] - + for i, coords_of_branch in enumerate(xyzr): + coords_to_plot = coords_of_branch[:, dims] col = cols[i] if i in highlight_branch_inds: col = highlight_cols[counter_highlight_branches % len(highlight_cols)] counter_highlight_branches += 1 (line,) = ax.plot( - coords_of_branch[:, 0], coords_of_branch[:, 1], c=col, label=f"ind {i}" + coords_to_plot[:, 0], coords_to_plot[:, 1], c=col, label=f"ind {i}" ) if i in highlight_branch_inds: lines.append(line) diff --git a/jaxley/utils/swc.py b/jaxley/utils/swc.py index 2c73592a..98a05965 100644 --- a/jaxley/utils/swc.py +++ b/jaxley/utils/swc.py @@ -41,8 +41,14 @@ def swc_to_jaxley( parents = parents.tolist() pathlengths = [0.1] + pathlengths radius_fns = [lambda x: content[0, 5] * np.ones_like(x)] + radius_fns + sorted_branches = [[0]] + sorted_branches - return parents, pathlengths, radius_fns, types + all_coords_of_branches = [] + for i, branch in enumerate(sorted_branches): + coords_of_branch = content[np.asarray(branch) - 1, 2:5] + all_coords_of_branches.append(coords_of_branch) + + return parents, pathlengths, radius_fns, types, all_coords_of_branches def _split_into_branches_and_sort(