Skip to content

Commit

Permalink
Simplify var_groups
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Feb 11, 2025
1 parent cce4bd6 commit d3f4dfc
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 55 deletions.
2 changes: 1 addition & 1 deletion src/scanpy/neighbors/_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def umap(
warnings.filterwarnings("ignore", message=r"Tensorflow not installed")
from umap.umap_ import fuzzy_simplicial_set

X = coo_matrix(([], ([], [])), shape=(n_obs, 1))
X = coo_matrix((n_obs, 1))
connectivities, _sigmas, _rhos = fuzzy_simplicial_set(
X,
n_neighbors,
Expand Down
115 changes: 61 additions & 54 deletions src/scanpy/plotting/_baseplot_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ class VBoundNorm(NamedTuple):
norm: Normalize | None


class VarGroups(NamedTuple):
    """Labelled, contiguous groups of variables along the variable axis."""

    # One label per group, in the same order as `positions`.
    labels: Sequence[str]
    # One `(first, last)` index pair per group; both ends are inclusive
    # indices into the flattened list of variable names.
    positions: Sequence[tuple[int, int]]


doc_common_groupby_plot_args = """\
title
Title for the figure
Expand Down Expand Up @@ -87,6 +92,8 @@ class BasePlot:

MAX_NUM_CATEGORIES = 500 # maximum number of categories allowed to be plotted

var_groups: VarGroups | None

@old_positionals(
"use_raw",
"log",
Expand Down Expand Up @@ -129,18 +136,24 @@ def __init__(
norm: Normalize | None = None,
**kwds,
):
self.var_names = var_names
self.var_group_labels = var_group_labels
self.var_group_positions = var_group_positions
self.var_names, self.var_groups = _var_groups(var_names, ref=adata.var_names)
match (var_group_labels, var_group_positions, self.var_groups):
case (None, None, _):
pass # inferred from `var_names`
case (None, _, _) | (_, None, _):
msg = "both or none of var_group_labels and var_group_positions must be set"
raise TypeError(msg)

Check warning on line 145 in src/scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/plotting/_baseplot_class.py#L144-L145

Added lines #L144 - L145 were not covered by tests
case (_, _, None):
if len(var_group_labels) != len(var_group_positions):
msg = "var_group_labels and var_group_positions must have the same length"
raise ValueError(msg)

Check warning on line 149 in src/scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/plotting/_baseplot_class.py#L148-L149

Added lines #L148 - L149 were not covered by tests
self.var_groups = VarGroups(var_group_labels, var_group_positions)
case (_, _, _):
msg = "var_group_labels and var_group_positions cannot be set if var_names is a dict"
raise TypeError(msg)

Check warning on line 153 in src/scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/plotting/_baseplot_class.py#L151-L153

Added lines #L151 - L153 were not covered by tests
self.var_group_rotation = var_group_rotation
self.width, self.height = figsize if figsize is not None else (None, None)

self.has_var_groups = (
var_group_positions is not None and len(var_group_positions) > 0
)

self._update_var_groups()

self.categories, self.obs_tidy = _prepare_dataframe(
adata,
self.var_names,
Expand Down Expand Up @@ -702,7 +715,7 @@ def make_figure(self):
width_ratios=[mainplot_width + self.group_extra_size, self.legends_width],
)

if self.has_var_groups:
if self.var_groups:
# add some space in case 'brackets' want to be plotted on top of the image
if self.are_axes_swapped:
var_groups_height = category_height
Expand Down Expand Up @@ -754,14 +767,14 @@ def make_figure(self):
if self.plot_group_extra is not None:
group_extra_ax = self.fig.add_subplot(mainplot_gs[2, 1], sharey=main_ax)
group_extra_orientation = "right"
if self.has_var_groups:
if self.var_groups:
gene_groups_ax = self.fig.add_subplot(mainplot_gs[1, 0], sharex=main_ax)
var_group_orientation = "top"
else:
if self.plot_group_extra:
group_extra_ax = self.fig.add_subplot(mainplot_gs[1, 0], sharex=main_ax)
group_extra_orientation = "top"
if self.has_var_groups:
if self.var_groups:
gene_groups_ax = self.fig.add_subplot(mainplot_gs[2, 1], sharey=main_ax)
var_group_orientation = "right"

Expand All @@ -781,11 +794,11 @@ def make_figure(self):
return_ax_dict["group_extra_ax"] = group_extra_ax

# plot group legends on top or left of main_ax (if given)
if self.has_var_groups:
if self.var_groups:
self._plot_var_groups_brackets(
gene_groups_ax,
group_positions=self.var_group_positions,
group_labels=self.var_group_labels,
group_positions=self.var_groups.positions,
group_labels=self.var_groups.labels,
rotation=self.var_group_rotation,
left_adjustment=0.2,
right_adjustment=0.7,
Expand Down Expand Up @@ -924,31 +937,30 @@ def _format_first_three_categories(_categories):
if self.var_names is not None:
var_names_idx_ordered = list(range(len(self.var_names)))

if self.has_var_groups:
if set(self.var_group_labels) == set(self.categories):
if self.var_groups:
if set(self.var_groups.labels) == set(self.categories):
positions_ordered = []
labels_ordered = []
position_start = 0
var_names_idx_ordered = []
for cat_name in categories_ordered:
idx = self.var_group_labels.index(cat_name)
position = self.var_group_positions[idx]
idx = self.var_groups.labels.index(cat_name)
position = self.var_groups.positions[idx]
_var_names = self.var_names[position[0] : position[1] + 1]
var_names_idx_ordered.extend(range(position[0], position[1] + 1))
positions_ordered.append(
(position_start, position_start + len(_var_names) - 1)
)
position_start += len(_var_names)
labels_ordered.append(self.var_group_labels[idx])
self.var_group_labels = labels_ordered
self.var_group_positions = positions_ordered
labels_ordered.append(self.var_groups.labels[idx])
self.var_groups = VarGroups(labels_ordered, positions_ordered)
else:
logg.warning(
"Groups are not reordered because the `groupby` categories "
"and the `var_group_labels` are different.\n"
f"categories: {_format_first_three_categories(self.categories)}\n"
"var_group_labels: "
f"{_format_first_three_categories(self.var_group_labels)}"
f"{_format_first_three_categories(self.var_groups.labels)}"
)

if var_names_idx_ordered is not None:
Expand Down Expand Up @@ -1082,35 +1094,30 @@ def _plot_var_groups_brackets(
axis="x", bottom=False, labelbottom=False, labeltop=False
)

def _update_var_groups(self) -> None:
"""
checks if var_names is a dict. Is this is the cases, then set the
correct values for var_group_labels and var_group_positions

updates var_names, var_group_labels, var_group_positions
"""
if isinstance(self.var_names, Mapping):
if self.has_var_groups:
logg.warning(
"`var_names` is a dictionary. This will reset the current "
"values of `var_group_labels` and `var_group_positions`."
)
var_group_labels = []
_var_names = []
var_group_positions = []
start = 0
for label, vars_list in self.var_names.items():
if isinstance(vars_list, str):
vars_list = [vars_list]
# use list() in case var_list is a numpy array or pandas series
_var_names.extend(list(vars_list))
var_group_labels.append(label)
var_group_positions.append((start, start + len(vars_list) - 1))
start += len(vars_list)
self.var_names = _var_names
self.var_group_labels = var_group_labels
self.var_group_positions = var_group_positions
self.has_var_groups = True

elif isinstance(self.var_names, str):
self.var_names = [self.var_names]
def _var_groups(
var_names: _VarNames | Mapping[str, _VarNames], *, ref: pd.Index[str]
) -> tuple[Sequence[str], VarGroups | None]:
"""
Normalize var_names.
If it’s a mapping, also return var_group_labels and var_group_positions.
"""

if not isinstance(var_names, Mapping):
var_names = [var_names] if isinstance(var_names, str) else var_names
return var_names, None

var_group_labels: list[str] = []
var_names_seq: list[str] = []
var_group_positions: list[tuple[int, int]] = []
for label, vars_list in var_names.items():
vars_list = [vars_list] if isinstance(vars_list, str) else vars_list
start = len(var_names_seq)
# use list() in case var_list is a numpy array or pandas series
var_names_seq.extend(list(vars_list))
var_group_labels.append(label)
var_group_positions.append((start, start + len(vars_list) - 1))
if not var_names_seq:
msg = "No valid var_names were passed."
raise ValueError(msg)

Check warning on line 1122 in src/scanpy/plotting/_baseplot_class.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/plotting/_baseplot_class.py#L1121-L1122

Added lines #L1121 - L1122 were not covered by tests
return var_names_seq, VarGroups(var_group_labels, var_group_positions)

0 comments on commit d3f4dfc

Please sign in to comment.