diff --git a/src/scanpy/neighbors/_connectivity.py b/src/scanpy/neighbors/_connectivity.py index 06adf00374..a4b7ff3fc3 100644 --- a/src/scanpy/neighbors/_connectivity.py +++ b/src/scanpy/neighbors/_connectivity.py @@ -122,7 +122,7 @@ def umap( warnings.filterwarnings("ignore", message=r"Tensorflow not installed") from umap.umap_ import fuzzy_simplicial_set - X = coo_matrix(([], ([], [])), shape=(n_obs, 1)) + X = coo_matrix((n_obs, 1)) connectivities, _sigmas, _rhos = fuzzy_simplicial_set( X, n_neighbors, diff --git a/src/scanpy/plotting/_baseplot_class.py b/src/scanpy/plotting/_baseplot_class.py index e14d387f84..859a6f93cd 100644 --- a/src/scanpy/plotting/_baseplot_class.py +++ b/src/scanpy/plotting/_baseplot_class.py @@ -38,6 +38,11 @@ class VBoundNorm(NamedTuple): norm: Normalize | None +class VarGroups(NamedTuple): + labels: Sequence[str] + positions: Sequence[tuple[int, int]] + + doc_common_groupby_plot_args = """\ title Title for the figure @@ -87,6 +92,8 @@ class BasePlot: MAX_NUM_CATEGORIES = 500 # maximum number of categories allowed to be plotted + var_groups: VarGroups | None + @old_positionals( "use_raw", "log", @@ -129,18 +136,24 @@ def __init__( norm: Normalize | None = None, **kwds, ): - self.var_names = var_names - self.var_group_labels = var_group_labels - self.var_group_positions = var_group_positions + self.var_names, self.var_groups = _var_groups(var_names, ref=adata.var_names) + match (var_group_labels, var_group_positions, self.var_groups): + case (None, None, _): + pass # inferred from `var_names` + case (None, _, _) | (_, None, _): + msg = "both or none of var_group_labels and var_group_positions must be set" + raise TypeError(msg) + case (_, _, None): + if len(var_group_labels) != len(var_group_positions): + msg = "var_group_labels and var_group_positions must have the same length" + raise ValueError(msg) + self.var_groups = VarGroups(var_group_labels, var_group_positions) + case (_, _, _): + msg = "var_group_labels and var_group_positions cannot be set if var_names is a dict" + raise TypeError(msg) self.var_group_rotation = var_group_rotation self.width, self.height = figsize if figsize is not None else (None, None) - self.has_var_groups = ( - var_group_positions is not None and len(var_group_positions) > 0 - ) - - self._update_var_groups() - self.categories, self.obs_tidy = _prepare_dataframe( adata, self.var_names, @@ -702,7 +715,7 @@ def make_figure(self): width_ratios=[mainplot_width + self.group_extra_size, self.legends_width], ) - if self.has_var_groups: + if self.var_groups: # add some space in case 'brackets' want to be plotted on top of the image if self.are_axes_swapped: var_groups_height = category_height @@ -754,14 +767,14 @@ def make_figure(self): if self.plot_group_extra is not None: group_extra_ax = self.fig.add_subplot(mainplot_gs[2, 1], sharey=main_ax) group_extra_orientation = "right" - if self.has_var_groups: + if self.var_groups: gene_groups_ax = self.fig.add_subplot(mainplot_gs[1, 0], sharex=main_ax) var_group_orientation = "top" else: if self.plot_group_extra: group_extra_ax = self.fig.add_subplot(mainplot_gs[1, 0], sharex=main_ax) group_extra_orientation = "top" - if self.has_var_groups: + if self.var_groups: gene_groups_ax = self.fig.add_subplot(mainplot_gs[2, 1], sharey=main_ax) var_group_orientation = "right" @@ -781,11 +794,11 @@ def make_figure(self): return_ax_dict["group_extra_ax"] = group_extra_ax # plot group legends on top or left of main_ax (if given) - if self.has_var_groups: + if self.var_groups: self._plot_var_groups_brackets( gene_groups_ax, - group_positions=self.var_group_positions, - group_labels=self.var_group_labels, + group_positions=self.var_groups.positions, + group_labels=self.var_groups.labels, rotation=self.var_group_rotation, left_adjustment=0.2, right_adjustment=0.7, @@ -924,31 +937,30 @@ def _format_first_three_categories(_categories): if self.var_names is not None: var_names_idx_ordered = list(range(len(self.var_names))) - if self.has_var_groups: - if set(self.var_group_labels) == set(self.categories): + if self.var_groups: + if set(self.var_groups.labels) == set(self.categories): positions_ordered = [] labels_ordered = [] position_start = 0 var_names_idx_ordered = [] for cat_name in categories_ordered: - idx = self.var_group_labels.index(cat_name) - position = self.var_group_positions[idx] + idx = self.var_groups.labels.index(cat_name) + position = self.var_groups.positions[idx] _var_names = self.var_names[position[0] : position[1] + 1] var_names_idx_ordered.extend(range(position[0], position[1] + 1)) positions_ordered.append( (position_start, position_start + len(_var_names) - 1) ) position_start += len(_var_names) - labels_ordered.append(self.var_group_labels[idx]) - self.var_group_labels = labels_ordered - self.var_group_positions = positions_ordered + labels_ordered.append(self.var_groups.labels[idx]) + self.var_groups = VarGroups(labels_ordered, positions_ordered) else: logg.warning( "Groups are not reordered because the `groupby` categories " "and the `var_group_labels` are different.\n" f"categories: {_format_first_three_categories(self.categories)}\n" "var_group_labels: " - f"{_format_first_three_categories(self.var_group_labels)}" + f"{_format_first_three_categories(self.var_groups.labels)}" ) if var_names_idx_ordered is not None: @@ -1082,35 +1094,30 @@ def _plot_var_groups_brackets( axis="x", bottom=False, labelbottom=False, labeltop=False ) - def _update_var_groups(self) -> None: - """ - checks if var_names is a dict. Is this is the cases, then set the - correct values for var_group_labels and var_group_positions - updates var_names, var_group_labels, var_group_positions - """ - if isinstance(self.var_names, Mapping): - if self.has_var_groups: - logg.warning( - "`var_names` is a dictionary. This will reset the current " - "values of `var_group_labels` and `var_group_positions`." - ) - var_group_labels = [] - _var_names = [] - var_group_positions = [] - start = 0 - for label, vars_list in self.var_names.items(): - if isinstance(vars_list, str): - vars_list = [vars_list] - # use list() in case var_list is a numpy array or pandas series - _var_names.extend(list(vars_list)) - var_group_labels.append(label) - var_group_positions.append((start, start + len(vars_list) - 1)) - start += len(vars_list) - self.var_names = _var_names - self.var_group_labels = var_group_labels - self.var_group_positions = var_group_positions - self.has_var_groups = True - - elif isinstance(self.var_names, str): - self.var_names = [self.var_names] +def _var_groups( + var_names: _VarNames | Mapping[str, _VarNames], *, ref: pd.Index[str] +) -> tuple[Sequence[str], VarGroups | None]: + """ + Normalize var_names. + If it’s a mapping, also return var_group_labels and var_group_positions. + """ + + if not isinstance(var_names, Mapping): + var_names = [var_names] if isinstance(var_names, str) else var_names + return var_names, None + + var_group_labels: list[str] = [] + var_names_seq: list[str] = [] + var_group_positions: list[tuple[int, int]] = [] + for label, vars_list in var_names.items(): + vars_list = [vars_list] if isinstance(vars_list, str) else vars_list + start = len(var_names_seq) + # use list() in case var_list is a numpy array or pandas series + var_names_seq.extend(list(vars_list)) + var_group_labels.append(label) + var_group_positions.append((start, start + len(vars_list) - 1)) + if not var_names_seq: + msg = "No valid var_names were passed." + raise ValueError(msg) + return var_names_seq, VarGroups(var_group_labels, var_group_positions)