diff --git a/neurax/utils/plot_utils.py b/neurax/utils/plot_utils.py index 74c76915..058f0998 100644 --- a/neurax/utils/plot_utils.py +++ b/neurax/utils/plot_utils.py @@ -131,12 +131,11 @@ def plot_swc( """ content = np.loadtxt(fname) sorted_branches, _ = _split_into_branches_and_sort( - content, max_branch_len=max_branch_len + content, max_branch_len=max_branch_len, sort=True ) parents = _build_parents(sorted_branches) if np.sum(np.asarray(parents) == -1) > 1.0: sorted_branches = [[0]] + sorted_branches - cols = [cols] * len(sorted_branches) counter_highlight_branches = 0 @@ -158,6 +157,7 @@ def plot_swc( if i in highlight_branch_inds: lines.append(line) - ax.legend(handles=lines, loc="upper left", bbox_to_anchor=(1.05, 1, 0, 0)) + if highlight_branch_inds: + ax.legend(handles=lines, loc="upper left", bbox_to_anchor=(1.05, 1, 0, 0)) return fig, ax