From 7e3dd157d109164d5f7e1bc82926d57c70c34572 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 17 Dec 2024 11:07:59 +0100 Subject: [PATCH 01/24] Modify error message if certifi is not installed (#3402) --- src/scanpy/_compat.py | 13 +++++++++++++ src/scanpy/readwrite.py | 23 +++++++++++------------ 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/scanpy/_compat.py b/src/scanpy/_compat.py index b97b1a8603..d2c69a9e37 100644 --- a/src/scanpy/_compat.py +++ b/src/scanpy/_compat.py @@ -106,6 +106,19 @@ def old_positionals(*old_positionals: str): return lambda func: func +if sys.version_info >= (3, 11): + + @wraps(BaseException.add_note) + def add_note(exc: BaseException, note: str) -> None: + exc.add_note(note) +else: + + def add_note(exc: BaseException, note: str) -> None: + if not hasattr(exc, "__notes__"): + exc.__notes__ = [] + exc.__notes__.append(note) + + if sys.version_info >= (3, 13): from warnings import deprecated as _deprecated else: diff --git a/src/scanpy/readwrite.py b/src/scanpy/readwrite.py index 3c958a1e50..07bd817ca5 100644 --- a/src/scanpy/readwrite.py +++ b/src/scanpy/readwrite.py @@ -36,7 +36,7 @@ from matplotlib.image import imread from . import logging as logg -from ._compat import old_positionals +from ._compat import add_note, old_positionals from ._settings import settings from ._utils import _empty @@ -993,15 +993,11 @@ def _get_filename_from_key(key, ext=None) -> Path: def _download(url: str, path: Path): - try: - import ipywidgets # noqa: F401 - from tqdm.auto import tqdm - except ImportError: - from tqdm import tqdm - from urllib.error import URLError from urllib.request import Request, urlopen + from tqdm.auto import tqdm + blocksize = 1024 * 8 blocknum = 0 @@ -1011,14 +1007,17 @@ def _download(url: str, path: Path): try: open_url = urlopen(req) except URLError: - logg.warning( - "Failed to open the url with default certificates, trying with certifi." - ) + msg = "Failed to open the url with default certificates." 
+ try: + from certifi import where + except ImportError as e: + add_note(e, f"{msg} Please install `certifi` and try again.") + raise + else: + logg.warning(f"{msg} Trying to use certifi.") from ssl import create_default_context - from certifi import where - open_url = urlopen(req, context=create_default_context(cafile=where())) with open_url as resp: From 86d656daa5553aa39804e21f8daab735cddbf6c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6k=C3=A7en=20Eraslan?= Date: Thu, 19 Dec 2024 08:48:10 -0800 Subject: [PATCH 02/24] Add replace option to subsample and rename function to sample (#943) --- docs/api/deprecated.md | 1 + docs/api/preprocessing.md | 2 +- docs/release-notes/943.feature.md | 1 + pyproject.toml | 4 +- src/scanpy/_compat.py | 41 ++++- src/scanpy/preprocessing/__init__.py | 4 +- .../preprocessing/_deprecated/sampling.py | 60 +++++++ src/scanpy/preprocessing/_simple.py | 165 +++++++++++++----- tests/test_package_structure.py | 1 + tests/test_preprocessing.py | 144 ++++++++++++--- tests/test_utils.py | 42 ++++- 11 files changed, 391 insertions(+), 74 deletions(-) create mode 100644 docs/release-notes/943.feature.md create mode 100644 src/scanpy/preprocessing/_deprecated/sampling.py diff --git a/docs/api/deprecated.md b/docs/api/deprecated.md index 4511f4b3a7..d09c1af405 100644 --- a/docs/api/deprecated.md +++ b/docs/api/deprecated.md @@ -11,4 +11,5 @@ pp.filter_genes_dispersion pp.normalize_per_cell + pp.subsample ``` diff --git a/docs/api/preprocessing.md b/docs/api/preprocessing.md index 4b17567a6b..36e732a6dc 100644 --- a/docs/api/preprocessing.md +++ b/docs/api/preprocessing.md @@ -31,7 +31,7 @@ For visual quality control, see {func}`~scanpy.pl.highest_expr_genes` and pp.normalize_total pp.regress_out pp.scale - pp.subsample + pp.sample pp.downsample_counts ``` diff --git a/docs/release-notes/943.feature.md b/docs/release-notes/943.feature.md new file mode 100644 index 0000000000..4f5474d762 --- /dev/null +++ b/docs/release-notes/943.feature.md @@ -0,0 +1 @@ +{func}`~scanpy.pp.sample` supports both upsampling and downsampling of observations and variables. {func}`~scanpy.pp.subsample` is now deprecated. 
{smaller}`G Eraslan` & {smaller}`P Angerer` diff --git a/pyproject.toml b/pyproject.toml index f1495442fe..b4b8abd1b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ classifiers = [ ] dependencies = [ "anndata>=0.8", - "numpy>=1.23", + "numpy>=1.24", "matplotlib>=3.6", "pandas >=1.5", "scipy>=1.8", @@ -60,7 +60,7 @@ dependencies = [ "networkx>=2.7", "natsort", "joblib", - "numba>=0.56", + "numba>=0.57", "umap-learn>=0.5,!=0.5.0", "pynndescent>=0.5", "packaging>=21.3", diff --git a/src/scanpy/_compat.py b/src/scanpy/_compat.py index d2c69a9e37..9ea7780b0d 100644 --- a/src/scanpy/_compat.py +++ b/src/scanpy/_compat.py @@ -4,7 +4,7 @@ import sys import warnings from dataclasses import dataclass, field -from functools import cache, partial, wraps +from functools import WRAPPER_ASSIGNMENTS, cache, partial, wraps from importlib.util import find_spec from pathlib import Path from typing import TYPE_CHECKING, Literal, ParamSpec, TypeVar, cast, overload @@ -224,3 +224,42 @@ def _numba_threading_layer() -> Layer: f" ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})" ) raise ValueError(msg) + + +def _legacy_numpy_gen( + random_state: _LegacyRandom | None = None, +) -> np.random.Generator: + """Return a random generator that behaves like the legacy one.""" + + if random_state is not None: + if isinstance(random_state, np.random.RandomState): + np.random.set_state(random_state.get_state(legacy=False)) + return _FakeRandomGen(random_state) + np.random.seed(random_state) + return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator())) + + +class _FakeRandomGen(np.random.Generator): + _state: np.random.RandomState + + def __init__(self, random_state: np.random.RandomState) -> None: + self._state = random_state + + @classmethod + def _delegate(cls) -> None: + for name, meth in np.random.Generator.__dict__.items(): + if name.startswith("_") or not callable(meth): + continue + + def mk_wrapper(name: str): + # Old pytest versions try to run the doctests + @wraps(meth, assigned=set(WRAPPER_ASSIGNMENTS) - {"__doc__"}) + def wrapper(self: _FakeRandomGen, *args, **kwargs): + return getattr(self._state, name)(*args, **kwargs) + + return wrapper + + setattr(cls, name, mk_wrapper(name)) + + +_FakeRandomGen._delegate() diff --git a/src/scanpy/preprocessing/__init__.py b/src/scanpy/preprocessing/__init__.py index 8c396d8640..4307cbb6c9 100644 --- a/src/scanpy/preprocessing/__init__.py +++ b/src/scanpy/preprocessing/__init__.py @@ -3,6 +3,7 @@ from ..neighbors import neighbors from ._combat import combat from ._deprecated.highly_variable_genes import filter_genes_dispersion +from ._deprecated.sampling import subsample from ._highly_variable_genes import highly_variable_genes from ._normalization import normalize_total from ._pca import pca @@ -17,8 +18,8 @@ log1p, normalize_per_cell, regress_out, + sample, sqrt, - subsample, ) __all__ = [ @@ -40,6 +41,7 @@ "log1p", "normalize_per_cell", "regress_out", + "sample", "scale", "sqrt", "subsample", diff --git a/src/scanpy/preprocessing/_deprecated/sampling.py b/src/scanpy/preprocessing/_deprecated/sampling.py new file mode 100644 index 0000000000..02619a2364 --- /dev/null +++ b/src/scanpy/preprocessing/_deprecated/sampling.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..._compat import _legacy_numpy_gen, old_positionals +from .._simple import sample + +if TYPE_CHECKING: + import numpy as np + from anndata import AnnData + from numpy.typing import NDArray + from scipy.sparse import 
csc_matrix, csr_matrix + + from ..._compat import _LegacyRandom + + CSMatrix = csr_matrix | csc_matrix + + +@old_positionals("n_obs", "random_state", "copy") +def subsample( + data: AnnData | np.ndarray | CSMatrix, + fraction: float | None = None, + *, + n_obs: int | None = None, + random_state: _LegacyRandom = 0, + copy: bool = False, +) -> AnnData | tuple[np.ndarray | CSMatrix, NDArray[np.int64]] | None: + """\ + Subsample to a fraction of the number of observations. + + .. deprecated:: 1.11.0 + + Use :func:`~scanpy.pp.sample` instead. + + Parameters + ---------- + data + The (annotated) data matrix of shape `n_obs` × `n_vars`. + Rows correspond to cells and columns to genes. + fraction + Subsample to this `fraction` of the number of observations. + n_obs + Subsample to this number of observations. + random_state + Random seed to change subsampling. + copy + If an :class:`~anndata.AnnData` is passed, + determines whether a copy is returned. + + Returns + ------- + Returns `X[obs_indices], obs_indices` if data is array-like, otherwise + subsamples the passed :class:`~anndata.AnnData` (`copy == False`) or + returns a subsampled copy of it (`copy == True`). + """ + + rng = _legacy_numpy_gen(random_state) + return sample( + data=data, fraction=fraction, n=n_obs, rng=rng, copy=copy, replace=False, axis=0 + ) diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index eaf9648690..29c267c3f4 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -8,20 +8,21 @@ import warnings from functools import singledispatch from itertools import repeat -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, TypeVar, overload import numba import numpy as np from anndata import AnnData from pandas.api.types import CategoricalDtype -from scipy.sparse import csr_matrix, issparse, isspmatrix_csr, spmatrix +from scipy.sparse import csc_matrix, csr_matrix, issparse, isspmatrix_csr, spmatrix from sklearn.utils import check_array, sparsefuncs from .. 
import logging as logg
-from .._compat import deprecated, njit, old_positionals
+from .._compat import DaskArray, deprecated, njit, old_positionals
 from .._settings import settings as sett
 from .._utils import (
     _check_array_function_arguments,
+    _resolve_axis,
     axis_sum,
     is_backed_type,
     raise_not_implemented_error_if_backed_type,
@@ -33,15 +34,11 @@
 from ._distributed import materialize_as_ndarray
 from ._utils import _to_dense

-# install dask if available
 try:
     import dask.array as da
 except ImportError:
     da = None

-# backwards compat
-from ._deprecated.highly_variable_genes import filter_genes_dispersion  # noqa: F401
-
 if TYPE_CHECKING:
     from collections.abc import Collection, Iterable, Sequence
     from numbers import Number
@@ -50,7 +47,13 @@
     import pandas as pd
     from numpy.typing import NDArray

-    from .._compat import DaskArray, _LegacyRandom
+    from .._compat import _LegacyRandom
+    from .._utils import RNGLike, SeedLike
+
+
+CSMatrix = csr_matrix | csc_matrix
+
+A = TypeVar("A", bound=np.ndarray | CSMatrix | DaskArray)


 @old_positionals(
@@ -825,17 +828,51 @@ def _regress_out_chunk(
     return np.vstack(responses_chunk_list)


-@old_positionals("n_obs", "random_state", "copy")
-def subsample(
-    data: AnnData | np.ndarray | spmatrix,
+@overload
+def sample(
+    data: AnnData,
     fraction: float | None = None,
     *,
-    n_obs: int | None = None,
-    random_state: _LegacyRandom = 0,
+    n: int | None = None,
+    rng: RNGLike | SeedLike | None = 0,
+    copy: Literal[False] = False,
+    replace: bool = False,
+    axis: Literal["obs", 0, "var", 1] = "obs",
+) -> None: ...
+@overload
+def sample(
+    data: AnnData,
+    fraction: float | None = None,
+    *,
+    n: int | None = None,
+    rng: RNGLike | SeedLike | None = None,
+    copy: Literal[True],
+    replace: bool = False,
+    axis: Literal["obs", 0, "var", 1] = "obs",
+) -> AnnData: ...
+@overload
+def sample(
+    data: A,
+    fraction: float | None = None,
+    *,
+    n: int | None = None,
+    rng: RNGLike | SeedLike | None = None,
     copy: bool = False,
-) -> AnnData | tuple[np.ndarray | spmatrix, NDArray[np.int64]] | None:
+    replace: bool = False,
+    axis: Literal["obs", 0, "var", 1] = "obs",
+) -> tuple[A, NDArray[np.int64]]: ...
+def sample(
+    data: AnnData | np.ndarray | CSMatrix | DaskArray,
+    fraction: float | None = None,
+    *,
+    n: int | None = None,
+    rng: RNGLike | SeedLike | None = None,
+    copy: bool = False,
+    replace: bool = False,
+    axis: Literal["obs", 0, "var", 1] = "obs",
+) -> AnnData | None | tuple[np.ndarray | CSMatrix | DaskArray, NDArray[np.int64]]:
     """\
-    Subsample to a fraction of the number of observations.
+    Sample observations or variables with or without replacement.

     Parameters
     ----------
@@ -843,49 +880,81 @@
         The (annotated) data matrix of shape `n_obs` × `n_vars`.
         Rows correspond to cells and columns to genes.
     fraction
-        Subsample to this `fraction` of the number of observations.
-    n_obs
-        Subsample to this number of observations.
+        Sample to this `fraction` of the number of observations or variables.
+        This can be larger than 1.0 if `replace=True`.
+        See `axis` and `replace`.
+    n
+        Sample to this number of observations or variables. See `axis`.
-    random_state
-        Random seed to change subsampling.
+    rng
+        Random seed or generator to use for sampling.
     copy
         If an :class:`~anndata.AnnData` is passed,
         determines whether a copy is returned.
+    replace
+        If True, samples are drawn with replacement.
+    axis
+        Sample `obs`\\ ervations (axis 0) or `var`\\ iables (axis 1).
Returns ------- - Returns `X[obs_indices], obs_indices` if data is array-like, otherwise - subsamples the passed :class:`~anndata.AnnData` (`copy == False`) or - returns a subsampled copy of it (`copy == True`). + If `isinstance(data, AnnData)` and `copy=False`, + this function returns `None`. Otherwise: + + `data[indices, :]` | `data[:, indices]` (depending on `axis`) + If `data` is array-like or `copy=True`, returns the subset. + `indices` : numpy.ndarray + If `data` is array-like, also returns the indices into the original. """ - np.random.seed(random_state) - old_n_obs = data.n_obs if isinstance(data, AnnData) else data.shape[0] - if n_obs is not None: - new_n_obs = n_obs - elif fraction is not None: - if fraction > 1 or fraction < 0: - raise ValueError(f"`fraction` needs to be within [0, 1], not {fraction}") - new_n_obs = int(fraction * old_n_obs) - logg.debug(f"... subsampled to {new_n_obs} data points") - else: - raise ValueError("Either pass `n_obs` or `fraction`.") - obs_indices = np.random.choice(old_n_obs, size=new_n_obs, replace=False) - if isinstance(data, AnnData): - if data.isbacked: - if copy: - return data[obs_indices].to_memory() - else: - raise NotImplementedError( - "Inplace subsampling is not implemented for backed objects." - ) + # parameter validation + if not copy and isinstance(data, AnnData) and data.isbacked: + msg = "Inplace sampling (`copy=False`) is not implemented for backed objects." + raise NotImplementedError(msg) + axis, axis_name = _resolve_axis(axis) + old_n = data.shape[axis] + match (fraction, n): + case (None, None): + msg = "Either `fraction` or `n` must be set." + raise TypeError(msg) + case (None, _): + pass + case (_, None): + if fraction < 0: + msg = f"`{fraction=}` needs to be nonnegative." + raise ValueError(msg) + if not replace and fraction > 1: + msg = f"If `replace=False`, `{fraction=}` needs to be within [0, 1]." + raise ValueError(msg) + n = int(fraction * old_n) + logg.debug(f"... sampled to {n} {axis_name}") + case _: + msg = "Providing both `fraction` and `n` is not allowed." 
+ raise TypeError(msg) + del fraction + + # actually do subsampling + rng = np.random.default_rng(rng) + indices = rng.choice(old_n, size=n, replace=replace) + + # overload 1: inplace AnnData subset + if not copy and isinstance(data, AnnData): + if axis_name == "obs": + data._inplace_subset_obs(indices) else: - if copy: - return data[obs_indices].copy() - else: - data._inplace_subset_obs(obs_indices) - else: - X = data - return X[obs_indices], obs_indices + data._inplace_subset_var(indices) + return None + + subset = data[indices] if axis_name == "obs" else data[:, indices] + + # overload 2: copy AnnData subset + if copy and isinstance(data, AnnData): + assert isinstance(subset, AnnData) + return subset.to_memory() if data.isbacked else subset.copy() + + # overload 3: return array and indices + assert isinstance(subset, np.ndarray | CSMatrix | DaskArray), type(subset) + if copy: + subset = subset.copy() + return subset, indices @renamed_arg("target_counts", "counts_per_cell") diff --git a/tests/test_package_structure.py b/tests/test_package_structure.py index 834c06d8b4..3541c561a5 100644 --- a/tests/test_package_structure.py +++ b/tests/test_package_structure.py @@ -138,6 +138,7 @@ class ExpectedSig(TypedDict): copy_sigs["sc.pp.filter_cells"] = None # unclear `inplace` situation copy_sigs["sc.pp.filter_genes"] = None # unclear `inplace` situation copy_sigs["sc.pp.subsample"] = None # returns indices along matrix +copy_sigs["sc.pp.sample"] = None # returns indices along matrix # partial exceptions: “data” instead of “adata” copy_sigs["sc.pp.log1p"]["first_name"] = "data" copy_sigs["sc.pp.normalize_per_cell"]["first_name"] = "data" diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index b8f5115b01..36283e7ed0 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -1,7 +1,10 @@ from __future__ import annotations +import warnings +from importlib.util import find_spec from itertools import product from pathlib import Path +from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -22,6 +25,13 @@ from testing.scanpy._helpers.data import pbmc3k, pbmc68k_reduced from testing.scanpy._pytest.params import ARRAY_TYPES +if TYPE_CHECKING: + from collections.abc import Callable + from typing import Any, Literal + + CSMatrix = sp.csc_matrix | sp.csr_matrix + + HERE = Path(__file__).parent DATA_PATH = HERE / "_data" @@ -134,34 +144,128 @@ def test_normalize_per_cell(): assert adata.X.sum(axis=1).tolist() == adata_sparse.X.sum(axis=1).A1.tolist() -def test_subsample(): - adata = AnnData(np.ones((200, 10))) - sc.pp.subsample(adata, n_obs=40) - assert adata.n_obs == 40 - sc.pp.subsample(adata, fraction=0.1) - assert adata.n_obs == 4 +@pytest.mark.parametrize("array_type", ARRAY_TYPES) +@pytest.mark.parametrize("which", ["copy", "inplace", "array"]) +@pytest.mark.parametrize( + ("axis", "fraction", "n", "replace", "expected"), + [ + pytest.param(0, None, 40, False, 40, id="obs-40-no_replace"), + pytest.param(0, 0.1, None, False, 20, id="obs-0.1-no_replace"), + pytest.param(0, None, 201, True, 201, id="obs-201-replace"), + pytest.param(0, None, 1, True, 1, id="obs-1-replace"), + pytest.param(1, None, 10, False, 10, id="var-10-no_replace"), + pytest.param(1, None, 11, True, 11, id="var-11-replace"), + pytest.param(1, 2.0, None, True, 20, id="var-2.0-replace"), + ], +) +def test_sample( + *, + array_type: Callable[[np.ndarray], np.ndarray | CSMatrix], + which: Literal["copy", "inplace", "array"], + axis: Literal[0, 1], + fraction: float | None, + n: 
int | None,
+    replace: bool,
+    expected: int,
+):
+    adata = AnnData(array_type(np.ones((200, 10))))
+
+    # ignoring this warning declaratively is a pain, so do it here
+    if find_spec("dask"):
+        import dask.array as da
+
+        warnings.filterwarnings("ignore", category=da.PerformanceWarning)
+    # can’t guarantee that duplicates are drawn when `replace=True`,
+    # so we just ignore the warning instead of using `with pytest.warns(...)`
+    warnings.filterwarnings(
+        "ignore" if replace else "error", r".*names are not unique", UserWarning
+    )
+    rv = sc.pp.sample(
+        adata.X if which == "array" else adata,
+        fraction,
+        n=n,
+        replace=replace,
+        axis=axis,
+        # `copy` only affects AnnData inputs
+        copy=dict(copy=True, inplace=False, array=False)[which],
+    )

+    match which:
+        case "copy":
+            subset = rv
+            assert rv is not adata
+            assert adata.shape == (200, 10)
+        case "inplace":
+            subset = adata
+            assert rv is None
+        case "array":
+            subset, indices = rv
+            assert len(indices) == expected
+            assert adata.shape == (200, 10)
+        case _:
+            pytest.fail(f"Unknown `{which=}`")

-def test_subsample_copy():
+    assert subset.shape == ((expected, 10) if axis == 0 else (200, expected))
+
+
+@pytest.mark.parametrize(
+    ("args", "exc", "pattern"),
+    [
+        pytest.param(
+            dict(), TypeError, r"Either `fraction` or `n` must be set", id="empty"
+        ),
+        pytest.param(
+            dict(n=10, fraction=0.2),
+            TypeError,
+            r"Providing both `fraction` and `n` is not allowed",
+            id="both",
+        ),
+        pytest.param(
+            dict(fraction=2),
+            ValueError,
+            r"If `replace=False`, `fraction=2` needs to be",
+            id="frac>1",
+        ),
+        pytest.param(
+            dict(fraction=-0.3),
+            ValueError,
+            r"`fraction=-0\.3` needs to be nonnegative",
+            id="frac<0",
+        ),
+    ],
+)
+def test_sample_error(args: dict[str, Any], exc: type[Exception], pattern: str):
     adata = AnnData(np.ones((200, 10)))
-    assert sc.pp.subsample(adata, n_obs=40, copy=True).shape == (40, 10)
-    assert sc.pp.subsample(adata, fraction=0.1, copy=True).shape == (20, 10)
+    with pytest.raises(exc, match=pattern):
+        sc.pp.sample(adata, **args)

-def test_subsample_copy_backed(tmp_path):
-    A = np.random.rand(200, 10).astype(np.float32)
-    adata_m = AnnData(A.copy())
-    adata_d = AnnData(A.copy())
-    filename = tmp_path / "test.h5ad"
-    adata_d.filename = filename
-    # This should not throw an error
-    assert sc.pp.subsample(adata_d, n_obs=40, copy=True).shape == (40, 10)
+def test_sample_backwards_compat():
+    expected = np.array(
+        [26, 86, 2, 55, 75, 93, 16, 73, 54, 95, 53, 92, 78, 13, 7, 30, 22, 24, 33, 8]
+    )
+    legacy_result, indices = sc.pp.subsample(np.arange(100), n_obs=20)
+    assert np.array_equal(indices, legacy_result), "arange choices should match indices"
+    assert np.array_equal(legacy_result, expected)
+
+
+def test_sample_copy_backed(tmp_path):
+    adata_m = AnnData(np.random.rand(200, 10).astype(np.float32))
+    adata_d = adata_m.copy()
+    adata_d.filename = tmp_path / "test.h5ad"
+
+    assert sc.pp.sample(adata_d, n=40, copy=True).shape == (40, 10)
     np.testing.assert_array_equal(
-        sc.pp.subsample(adata_m, n_obs=40, copy=True).X,
-        sc.pp.subsample(adata_d, n_obs=40, copy=True).X,
+        sc.pp.sample(adata_m, n=40, copy=True, rng=0).X,
+        sc.pp.sample(adata_d, n=40, copy=True, rng=0).X,
     )
+
+
+def test_sample_copy_backed_error(tmp_path):
+    adata_d = AnnData(np.random.rand(200, 10).astype(np.float32))
+    adata_d.filename = tmp_path / "test.h5ad"
     with pytest.raises(NotImplementedError):
-        sc.pp.subsample(adata_d, n_obs=40, copy=False)
+        sc.pp.sample(adata_d, n=40, copy=False)


 @pytest.mark.parametrize("array_type", ARRAY_TYPES)
diff --git
a/tests/test_utils.py b/tests/test_utils.py index f8a38a5f9d..81369a6938 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ from operator import mul, truediv from types import ModuleType +from typing import TYPE_CHECKING import numpy as np import pytest @@ -9,7 +10,7 @@ from packaging.version import Version from scipy.sparse import csr_matrix, issparse -from scanpy._compat import DaskArray, pkg_version +from scanpy._compat import DaskArray, _legacy_numpy_gen, pkg_version from scanpy._utils import ( axis_mul_or_truediv, axis_sum, @@ -26,6 +27,9 @@ ARRAY_TYPES_SPARSE_DASK_UNSUPPORTED, ) +if TYPE_CHECKING: + from typing import Any + def test_descend_classes_and_funcs(): # create module hierarchy @@ -247,3 +251,39 @@ def test_is_constant_dask(request: pytest.FixtureRequest, axis, expected, block_ x = da.from_array(np.array(x_data), chunks=2).map_blocks(block_type) result = is_constant(x, axis=axis).compute() np.testing.assert_array_equal(expected, result) + + +@pytest.mark.parametrize("seed", [0, 1, 1256712675]) +@pytest.mark.parametrize("pass_seed", [True, False], ids=["pass_seed", "set_seed"]) +@pytest.mark.parametrize("func", ["choice"]) +def test_legacy_numpy_gen(*, seed: int, pass_seed: bool, func: str): + np.random.seed(seed) + state_before = np.random.get_state(legacy=False) + + arrs: dict[bool, np.ndarray] = {} + states_after: dict[bool, dict[str, Any]] = {} + for direct in [True, False]: + if not pass_seed: + np.random.seed(seed) + arrs[direct] = _mk_random(func, direct=direct, seed=seed if pass_seed else None) + states_after[direct] = np.random.get_state(legacy=False) + + np.testing.assert_array_equal(arrs[True], arrs[False]) + np.testing.assert_equal( + *states_after.values(), err_msg="both should affect global state the same" + ) + # they should affect the global state + with pytest.raises(AssertionError): + np.testing.assert_equal(states_after[True], state_before) + + +def _mk_random(func: str, *, direct: bool, seed: int | None) -> np.ndarray: + if direct and seed is not None: + np.random.seed(seed) + gen = np.random if direct else _legacy_numpy_gen(seed) + match func: + case "choice": + arr = np.arange(1000) + return gen.choice(arr, size=(100, 100)) + case _: + pytest.fail(f"Unknown {func=}") From e3efba280eed0726eb3e397715069cbaec761ba4 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Thu, 19 Dec 2024 17:49:55 +0100 Subject: [PATCH 03/24] Switch to session-info2 (#3384) --- docs/conf.py | 1 + docs/release-notes/3384.feature.md | 1 + pyproject.toml | 2 +- src/scanpy/logging.py | 92 ++++++++---------------------- tests/test_logging.py | 4 +- 5 files changed, 31 insertions(+), 69 deletions(-) create mode 100644 docs/release-notes/3384.feature.md diff --git a/docs/conf.py b/docs/conf.py index 2c79aa8d82..155869b360 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -142,6 +142,7 @@ rapids_singlecell=("https://rapids-singlecell.readthedocs.io/en/latest/", None), scipy=("https://docs.scipy.org/doc/scipy/", None), seaborn=("https://seaborn.pydata.org/", None), + session_info2=("https://session-info2.readthedocs.io/en/stable/", None), sklearn=("https://scikit-learn.org/stable/", None), ) diff --git a/docs/release-notes/3384.feature.md b/docs/release-notes/3384.feature.md new file mode 100644 index 0000000000..755af9a8a3 --- /dev/null +++ b/docs/release-notes/3384.feature.md @@ -0,0 +1 @@ +Switch {func}`~scanpy.logging.print_header` and {func}`~scanpy.logging.print_versions` to {mod}`session_info2` {smaller}`P Angerer` diff --git a/pyproject.toml b/pyproject.toml index b4b8abd1b1..8e23afb14b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ dependencies = [ "umap-learn>=0.5,!=0.5.0", "pynndescent>=0.5", "packaging>=21.3", - "session-info", + "session-info2", "legacy-api-wrap>=1.4", # for positional API deprecations "typing-extensions; python_version < '3.13'", ] diff --git a/src/scanpy/logging.py b/src/scanpy/logging.py index 168c3b5405..3aa0ca494c 100644 --- a/src/scanpy/logging.py +++ b/src/scanpy/logging.py @@ -4,17 +4,20 @@ import logging import sys -import warnings from datetime import datetime, timedelta, timezone from functools import partial, update_wrapper from logging import CRITICAL, DEBUG, ERROR, INFO, WARNING -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, overload import anndata.logging +from ._compat import deprecated + if TYPE_CHECKING: from typing import IO + from session_info2 import SessionInfo + from ._settings import ScanpyConfig @@ -127,33 +130,11 @@ def format(self, record: logging.LogRecord): get_memory_usage = anndata.logging.get_memory_usage -_DEPENDENCIES_NUMERICS = [ - "anndata", # anndata actually shouldn't, but as long as it's in development - "umap", - "numpy", - "scipy", - "pandas", - ("sklearn", "scikit-learn"), - "statsmodels", - "igraph", - "louvain", - "leidenalg", - "pynndescent", -] - - -def _versions_dependencies(dependencies): - # this is not the same as the requirements! - for mod in dependencies: - mod_name, dist_name = mod if isinstance(mod, tuple) else (mod, mod) - try: - imp = __import__(mod_name) - yield dist_name, imp.__version__ - except (ImportError, AttributeError): - pass - - -def print_header(*, file=None): +@overload +def print_header(*, file: None = None) -> SessionInfo: ... +@overload +def print_header(*, file: IO[str]) -> None: ... +def print_header(*, file: IO[str] | None = None): """\ Versions that might influence the numerical results. Matplotlib and Seaborn are excluded from this. @@ -163,50 +144,27 @@ def print_header(*, file=None): file Optional path for dependency output. 
""" + from session_info2 import session_info - modules = ["scanpy"] + _DEPENDENCIES_NUMERICS - print( - " ".join(f"{mod}=={ver}" for mod, ver in _versions_dependencies(modules)), - file=file or sys.stdout, - ) + sinfo = session_info(os=True, cpu=True, gpu=True, dependencies=True) + + if file is not None: + print(sinfo, file=file) + return + + return sinfo -def print_versions(*, file: IO[str] | None = None): +@deprecated("Use `print_header` instead") +def print_versions() -> SessionInfo: """\ - Print versions of imported packages, OS, and jupyter environment. + Alias for `print_header`. - For more options (including rich output) use `session_info.show` directly. + .. deprecated:: 1.11.0 - Parameters - ---------- - file - Optional path for output. + Use :func:`print_header` instead. """ - import session_info - - if file is not None: - from contextlib import redirect_stdout - - warnings.warn( - "Passing argument 'file' to print_versions is deprecated, and will be " - "removed in a future version.", - FutureWarning, - ) - with redirect_stdout(file): - print_versions() - else: - session_info.show( - dependencies=True, - html=False, - excludes=[ - "builtins", - "stdlib_list", - "importlib_metadata", - # Special module present if test coverage being calculated - # https://gitlab.com/joelostblom/session_info/-/issues/10 - "$coverage", - ], - ) + return print_header() def print_version_and_date(*, file=None): @@ -235,7 +193,7 @@ def _copy_docs_and_signature(fn): def error( msg: str, *, - time: datetime = None, + time: datetime | None = None, deep: str | None = None, extra: dict | None = None, ) -> datetime: diff --git a/tests/test_logging.py b/tests/test_logging.py index 3f8a3ee97d..81b4acbf38 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -142,6 +142,8 @@ def test_call_outputs(func): """ output_io = StringIO() with redirect_stdout(output_io): - func() + out = func() + if out is not None: + print(out) output = output_io.getvalue() assert output != "" From 1cd5a00d750d061b04b94a9a6bf780a161f1da9e Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Thu, 19 Dec 2024 19:30:49 +0100 Subject: [PATCH 04/24] Scipy 1.15 compat, some test refactors (#3409) --- src/scanpy/tools/_rank_genes_groups.py | 2 +- tests/test_backed.py | 4 +- tests/test_filter_rank_genes_groups.py | 211 +++++++++--------------- tests/test_preprocessing_distributed.py | 3 +- 4 files changed, 79 insertions(+), 141 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 59526ee516..aa4428dad1 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -854,7 +854,7 @@ def filter_rank_genes_groups( if not use_logfolds or not use_fraction: sub_X = adata.raw[:, var_names].X if use_raw else adata[:, var_names].X - in_group = adata.obs[groupby] == cluster + in_group = (adata.obs[groupby] == cluster).to_numpy() X_in = sub_X[in_group] X_out = sub_X[~in_group] diff --git a/tests/test_backed.py b/tests/test_backed.py index 787edf9c21..bfa1d79592 100644 --- a/tests/test_backed.py +++ b/tests/test_backed.py @@ -91,8 +91,8 @@ def test_log1p_backed_errors(backed_adata): def test_scatter_backed(backed_adata): sc.pp.pca(backed_adata, chunked=True) - sc.pl.scatter(backed_adata, color="0", basis="pca") + sc.pl.scatter(backed_adata, color="0", basis="pca", show=False) def test_dotplot_backed(backed_adata): - sc.pl.dotplot(backed_adata, ["0", "1", "2", "3"], groupby="cat") + sc.pl.dotplot(backed_adata, ["0", "1", "2", "3"], groupby="cat", show=False) diff --git a/tests/test_filter_rank_genes_groups.py b/tests/test_filter_rank_genes_groups.py index 26851bb102..a64ac983f3 100644 --- a/tests/test_filter_rank_genes_groups.py +++ b/tests/test_filter_rank_genes_groups.py @@ -1,159 +1,96 @@ from __future__ import annotations import numpy as np +import pytest from scanpy.tools import filter_rank_genes_groups, rank_genes_groups from testing.scanpy._helpers.data import pbmc68k_reduced -names_no_reference = np.array( +NAMES_NO_REF = [ + ["CD3D", "ITM2A", "CD3D", "CCL5", "CD7", "nan", "CD79A", "nan", "NKG7", "LYZ"], + ["CD3E", "CD3D", "nan", "NKG7", "CD3D", "AIF1", "CD79B", "nan", "GNLY", "CST3"], + ["IL32", "RPL39", "nan", "CST7", "nan", "nan", "nan", "SNHG7", "CD7", "nan"], + ["nan", "SRSF7", "IL32", "GZMA", "nan", "LST1", "IGJ", "nan", "CTSW", "nan"], + ["nan", "nan", "CD2", "CTSW", "CD8B", "TYROBP", "ISG20", "SNHG8", "GZMB", "nan"], +] + +NAMES_REF = [ + ["CD3D", "ITM2A", "CD3D", "nan", "CD3D", "nan", "CD79A", "nan", "CD7"], + ["nan", "nan", "nan", "CD3D", "nan", "AIF1", "nan", "nan", "NKG7"], + ["nan", "nan", "nan", "NKG7", "nan", "FCGR3A", "ISG20", "SNHG7", "CTSW"], + ["nan", "CD3D", "nan", "CCL5", "CD7", "nan", "CD79B", "nan", "GNLY"], + ["CD3E", "IL32", "nan", "IL32", "CD27", "FCER1G", "nan", "nan", "nan"], +] + +NAMES_NO_REF_COMPARE_ABS = [ [ - ["CD3D", "ITM2A", "CD3D", "CCL5", "CD7", "nan", "CD79A", "nan", "NKG7", "LYZ"], - ["CD3E", "CD3D", "nan", "NKG7", "CD3D", "AIF1", "CD79B", "nan", "GNLY", "CST3"], - ["IL32", "RPL39", "nan", "CST7", "nan", "nan", "nan", "SNHG7", "CD7", "nan"], - ["nan", "SRSF7", "IL32", "GZMA", "nan", "LST1", "IGJ", "nan", "CTSW", "nan"], - [ - "nan", - "nan", - "CD2", - "CTSW", - "CD8B", - "TYROBP", - "ISG20", - "SNHG8", - "GZMB", - "nan", - ], - ] -) - -names_reference = np.array( + *("CD3D", "ITM2A", "HLA-DRB1", "CCL5", "HLA-DPA1"), + *("nan", "CD79A", "nan", "NKG7", "LYZ"), + ], [ - ["CD3D", "ITM2A", "CD3D", "nan", "CD3D", "nan", "CD79A", "nan", "CD7"], - ["nan", "nan", "nan", "CD3D", "nan", "AIF1", "nan", "nan", "NKG7"], - ["nan", "nan", "nan", "NKG7", "nan", 
"FCGR3A", "ISG20", "SNHG7", "CTSW"], - ["nan", "CD3D", "nan", "CCL5", "CD7", "nan", "CD79B", "nan", "GNLY"], - ["CD3E", "IL32", "nan", "IL32", "CD27", "FCER1G", "nan", "nan", "nan"], - ] -) - -names_compare_abs = np.array( + *("HLA-DPA1", "nan", "CD3D", "NKG7", "HLA-DRB1"), + *("AIF1", "CD79B", "nan", "GNLY", "CST3"), + ], [ - [ - "CD3D", - "ITM2A", - "HLA-DRB1", - "CCL5", - "HLA-DPA1", - "nan", - "CD79A", - "nan", - "NKG7", - "LYZ", - ], - [ - "HLA-DPA1", - "nan", - "CD3D", - "NKG7", - "HLA-DRB1", - "AIF1", - "CD79B", - "nan", - "GNLY", - "CST3", - ], - [ - "nan", - "PSAP", - "CD74", - "CST7", - "CD74", - "PSAP", - "FCER1G", - "SNHG7", - "CD7", - "HLA-DRA", - ], - [ - "IL32", - "nan", - "HLA-DRB5", - "GZMA", - "HLA-DRB5", - "LST1", - "nan", - "nan", - "CTSW", - "HLA-DRB1", - ], - [ - "nan", - "FCER1G", - "HLA-DPB1", - "CTSW", - "HLA-DPB1", - "TYROBP", - "TYROBP", - "S100A10", - "GZMB", - "HLA-DPA1", - ], - ] -) - - -def test_filter_rank_genes_groups(): - adata = pbmc68k_reduced() - - # fix filter defaults - args = { - "adata": adata, - "key_added": "rank_genes_groups_filtered", - "min_in_group_fraction": 0.25, - "min_fold_change": 1, - "max_out_group_fraction": 0.5, - } - - rank_genes_groups( - adata, "bulk_labels", reference="Dendritic", method="wilcoxon", n_genes=5 - ) - filter_rank_genes_groups(**args) - - assert np.array_equal( - names_reference, - np.array(adata.uns["rank_genes_groups_filtered"]["names"].tolist()), - ) + *("nan", "PSAP", "CD74", "CST7", "CD74"), + *("PSAP", "FCER1G", "SNHG7", "CD7", "HLA-DRA"), + ], + [ + *("IL32", "nan", "HLA-DRB5", "GZMA", "HLA-DRB5"), + *("LST1", "nan", "nan", "CTSW", "HLA-DRB1"), + ], + [ + *("nan", "FCER1G", "HLA-DPB1", "CTSW", "HLA-DPB1"), + *("TYROBP", "TYROBP", "S100A10", "GZMB", "HLA-DPA1"), + ], +] - rank_genes_groups(adata, "bulk_labels", method="wilcoxon", n_genes=5) - filter_rank_genes_groups(**args) - assert np.array_equal( - names_no_reference, - np.array(adata.uns["rank_genes_groups_filtered"]["names"].tolist()), - ) +EXPECTED = { + ("Dendritic", False): np.array(NAMES_REF), + ("rest", False): np.array(NAMES_NO_REF), + ("rest", True): np.array(NAMES_NO_REF_COMPARE_ABS), +} - rank_genes_groups(adata, "bulk_labels", method="wilcoxon", pts=True, n_genes=5) - filter_rank_genes_groups(**args) - assert np.array_equal( - names_no_reference, - np.array(adata.uns["rank_genes_groups_filtered"]["names"].tolist()), - ) +@pytest.mark.parametrize( + ("reference", "pts", "abs"), + [ + pytest.param("Dendritic", False, False, id="ref-no_pts-no_abs"), + pytest.param("rest", False, False, id="rest-no_pts-no_abs"), + pytest.param("rest", True, False, id="rest-pts-no_abs"), + pytest.param("rest", True, True, id="rest-pts-abs"), + ], +) +def test_filter_rank_genes_groups(reference, pts, abs): + adata = pbmc68k_reduced() - # test compare_abs rank_genes_groups( - adata, "bulk_labels", method="wilcoxon", pts=True, rankby_abs=True, n_genes=5 - ) - - filter_rank_genes_groups( adata, - compare_abs=True, - min_in_group_fraction=-1, - max_out_group_fraction=1, - min_fold_change=3.1, + "bulk_labels", + reference=reference, + pts=pts, + method="wilcoxon", + rankby_abs=abs, + n_genes=5, ) + if abs: + filter_rank_genes_groups( + adata, + compare_abs=True, + min_in_group_fraction=-1, + max_out_group_fraction=1, + min_fold_change=3.1, + ) + else: + filter_rank_genes_groups( + adata, + min_in_group_fraction=0.25, + min_fold_change=1, + max_out_group_fraction=0.5, + ) assert np.array_equal( - names_compare_abs, + EXPECTED[reference, abs], 
np.array(adata.uns["rank_genes_groups_filtered"]["names"].tolist()), ) diff --git a/tests/test_preprocessing_distributed.py b/tests/test_preprocessing_distributed.py index a1b99121ef..afb120b982 100644 --- a/tests/test_preprocessing_distributed.py +++ b/tests/test_preprocessing_distributed.py @@ -40,13 +40,13 @@ def adata() -> AnnData: return a -@filter_oldformatwarning @pytest.fixture( params=[ pytest.param("direct", marks=[needs.zappy]), pytest.param("dask", marks=[needs.dask, pytest.mark.anndata_dask_support]), ] ) +@filter_oldformatwarning def adata_dist(request: pytest.FixtureRequest) -> AnnData: # regular anndata except for X, which we replace on the next line a = read_zarr(input_file) @@ -75,6 +75,7 @@ def test_log1p(adata: AnnData, adata_dist: AnnData): npt.assert_allclose(result, adata.X) +@pytest.mark.filterwarnings("ignore:Use sc.pp.normalize_total instead:FutureWarning") def test_normalize_per_cell( request: pytest.FixtureRequest, adata: AnnData, adata_dist: AnnData ): From ac4c629ba1b50642618e4b632a21e5de903ce8ec Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 19 Dec 2024 19:59:59 +0100 Subject: [PATCH 05/24] Deprecate visium (#3407) --- docs/conf.py | 1 + docs/release-notes/1.5.0.md | 2 +- docs/release-notes/3407.misc.md | 7 + docs/tutorials/index.md | 22 +- docs/tutorials/spatial/basic-analysis.ipynb | 1 - docs/tutorials/spatial/index.md | 8 - .../spatial/integration-scanorama.ipynb | 1 - src/scanpy/datasets/_datasets.py | 6 +- src/scanpy/plotting/_tools/scatterplots.py | 7 +- src/scanpy/readwrite.py | 6 +- tests/test_datasets.py | 3 + tests/test_embedding_plots.py | 566 ------------------ tests/test_plotting.py | 5 +- tests/test_plotting_embedded/conftest.py | 66 ++ .../test_plotting_embedded/test_embeddings.py | 253 ++++++++ tests/test_plotting_embedded/test_spatial.py | 267 +++++++++ tests/test_read_10x.py | 1 + 17 files changed, 625 insertions(+), 597 deletions(-) create mode 100644 docs/release-notes/3407.misc.md delete mode 120000 docs/tutorials/spatial/basic-analysis.ipynb delete mode 100644 docs/tutorials/spatial/index.md delete mode 120000 docs/tutorials/spatial/integration-scanorama.ipynb delete mode 100644 tests/test_embedding_plots.py create mode 100644 tests/test_plotting_embedded/conftest.py create mode 100644 tests/test_plotting_embedded/test_embeddings.py create mode 100644 tests/test_plotting_embedded/test_spatial.py diff --git a/docs/conf.py b/docs/conf.py index 155869b360..e17aa9df0f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -143,6 +143,7 @@ scipy=("https://docs.scipy.org/doc/scipy/", None), seaborn=("https://seaborn.pydata.org/", None), session_info2=("https://session-info2.readthedocs.io/en/stable/", None), + squidpy=("https://squidpy.readthedocs.io/en/stable/", None), sklearn=("https://scikit-learn.org/stable/", None), ) diff --git a/docs/release-notes/1.5.0.md b/docs/release-notes/1.5.0.md index 922e758723..956ceb9493 100644 --- a/docs/release-notes/1.5.0.md +++ b/docs/release-notes/1.5.0.md @@ -5,7 +5,7 @@ The `1.5.0` release adds a lot of new functionality, much of which takes advanta #### Spatial data support -- Basic analysis {doc}`/tutorials/spatial/basic-analysis` and integration with single cell data {doc}`/tutorials/spatial/integration-scanorama` {smaller}`G Palla` +- Tutorials for basic analysis and integration with single cell data {smaller}`G Palla` - {func}`~scanpy.read_visium` read 10x Visium data {pr}`1034` {smaller}`G Palla, P Angerer, I Virshup` - {func}`~scanpy.datasets.visium_sge` load Visium data directly from 10x 
Genomics {pr}`1013` {smaller}`M Mirkazemi, G Palla, P Angerer`
 - {func}`~scanpy.pl.spatial` plot spatial data {pr}`1012` {smaller}`G Palla, P Angerer`
diff --git a/docs/release-notes/3407.misc.md b/docs/release-notes/3407.misc.md
new file mode 100644
index 0000000000..4670de6fd4
--- /dev/null
+++ b/docs/release-notes/3407.misc.md
@@ -0,0 +1,7 @@
+| Deprecate … | in favor of … |
+| --- | --- |
+| {func}`scanpy.read_visium` | {func}`squidpy.read.visium` |
+| {func}`scanpy.datasets.visium_sge` | {func}`squidpy.datasets.visium` |
+| {func}`scanpy.pl.spatial` | {func}`squidpy.pl.spatial_scatter` |
+
+{smaller}`P Angerer`
diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md
index ee57056a6d..b20ee2b762 100644
--- a/docs/tutorials/index.md
+++ b/docs/tutorials/index.md
@@ -37,19 +37,6 @@ trajectories/index

 ## Spatial data

-```{seealso}
-For more up-to-date tutorials on working with spatial data, see:
-
-* [SquidPy tutorials](https://squidpy.readthedocs.io/en/stable/notebooks/tutorials/index.html)
-* [SpatialData tutorials](https://spatialdata.scverse.org/en/latest/tutorials/notebooks/notebooks.html)
-* [Scverse ecosystem spatial tutorials](https://scverse.org/learn/)
-```
-
-```{toctree}
-:maxdepth: 2
-
-spatial/index
-```

 ## Experimental

@@ -64,3 +51,12 @@ experimental/index
 A number of older tutorials can be found at:

 * The [`scanpy_usage`](https://github.com/scverse/scanpy_usage) repository
+
+```{seealso}
+Scanpy used to have tutorials for its (now deprecated) spatial data functionality.
+For up-to-date tutorials on working with spatial data, see:
+
+* SquidPy {doc}`squidpy:notebooks/tutorials/index`
+* [SpatialData tutorials](https://spatialdata.scverse.org/en/latest/tutorials/notebooks/notebooks.html)
+* [Scverse ecosystem spatial tutorials](https://scverse.org/learn/)
+```
diff --git a/docs/tutorials/spatial/basic-analysis.ipynb b/docs/tutorials/spatial/basic-analysis.ipynb
deleted file mode 120000
index 66d9e48121..0000000000
--- a/docs/tutorials/spatial/basic-analysis.ipynb
+++ /dev/null
@@ -1 +0,0 @@
-../../../notebooks/spatial/basic-analysis.ipynb
\ No newline at end of file
diff --git a/docs/tutorials/spatial/index.md b/docs/tutorials/spatial/index.md
deleted file mode 100644
index 801b901e53..0000000000
--- a/docs/tutorials/spatial/index.md
+++ /dev/null
@@ -1,8 +0,0 @@
-## Spatial
-
-```{toctree}
-:maxdepth: 1
-
-basic-analysis
-integration-scanorama
-```
diff --git a/docs/tutorials/spatial/integration-scanorama.ipynb b/docs/tutorials/spatial/integration-scanorama.ipynb
deleted file mode 120000
index 5143681577..0000000000
--- a/docs/tutorials/spatial/integration-scanorama.ipynb
+++ /dev/null
@@ -1 +0,0 @@
-../../../notebooks/spatial/integration-scanorama.ipynb
\ No newline at end of file
diff --git a/src/scanpy/datasets/_datasets.py b/src/scanpy/datasets/_datasets.py
index df510b3209..8859de4d74 100644
--- a/src/scanpy/datasets/_datasets.py
+++ b/src/scanpy/datasets/_datasets.py
@@ -9,7 +9,7 @@
 from anndata import AnnData

 from .. import _utils
-from .._compat import old_positionals
+from .._compat import deprecated, old_positionals
 from .._settings import settings
 from .._utils._doctests import doctest_internet, doctest_needs
 from ..readwrite import read, read_visium
@@ -509,6 +509,7 @@ def _download_visium_dataset(
     return sample_dir


+@deprecated("Use `squidpy.datasets.visium` instead.")
 @doctest_internet
 @check_datasetdir_exists
 def visium_sge(
@@ -519,6 +520,9 @@ def visium_sge(
     """\
     Processed Visium Spatial Gene Expression data from 10x Genomics’ database.

+    .. deprecated:: 1.11.0
+       Use :func:`squidpy.datasets.visium` instead.
+
     The database_ can be browsed online to find the ``sample_id`` you want.

     .. _database: https://support.10xgenomics.com/spatial-gene-expression/datasets
diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py
index 4ce39f7211..e2564eb17f 100644
--- a/src/scanpy/plotting/_tools/scatterplots.py
+++ b/src/scanpy/plotting/_tools/scatterplots.py
@@ -28,6 +28,7 @@
 from packaging.version import Version

 from ... import logging as logg
+from ..._compat import deprecated
 from ..._settings import settings
 from ..._utils import (
     Empty,  # noqa: TCH001
@@ -919,6 +920,7 @@ def pca(
     return axs


+@deprecated("Use `squidpy.pl.spatial_scatter` instead.")
 @_wraps_plot_scatter
 @_doc_params(
     adata_color_etc=doc_adata_color_etc,
@@ -948,6 +950,9 @@ def spatial(
     """\
     Scatter plot in spatial coordinates.

+    .. deprecated:: 1.11.0
+       Use :func:`squidpy.pl.spatial_scatter` instead.
+
     This function allows overlaying data on top of images.
     Use the parameter `img_key` to see the image in the background,
     and the parameter `library_id` to select the image.
@@ -994,8 +999,6 @@ def spatial(
     --------
     :func:`scanpy.datasets.visium_sge`
         Example visium data.
-    :doc:`/tutorials/spatial/basic-analysis`
-        Tutorial on spatial analysis.
     """
     # get default image params if available
     library_id, spatial_data = _check_spatial_data(adata.uns, library_id)
diff --git a/src/scanpy/readwrite.py b/src/scanpy/readwrite.py
index 07bd817ca5..3333fbc0a1 100644
--- a/src/scanpy/readwrite.py
+++ b/src/scanpy/readwrite.py
@@ -36,7 +36,7 @@
 from matplotlib.image import imread

 from . import logging as logg
-from ._compat import add_note, old_positionals
+from ._compat import add_note, deprecated, old_positionals
 from ._settings import settings
 from ._utils import _empty

@@ -366,6 +366,7 @@ def _read_v3_10x_h5(filename, *, start=None):
         raise Exception("File is missing one or more required datasets.")


+@deprecated("Use `squidpy.read.visium` instead.")
 def read_visium(
     path: Path | str,
     genome: str | None = None,
@@ -378,6 +379,9 @@ def read_visium(
     """\
     Read 10x-Genomics-formatted visium dataset.

+    .. deprecated:: 1.11.0
+       Use :func:`squidpy.read.visium` instead.
+
     In addition to reading regular 10x output,
     this looks for the `spatial` folder and loads images,
     coordinates and scale factors.
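For code being migrated off the three functions deprecated in this patch, a minimal sketch of the squidpy replacements named in the 3407.misc.md table above (this assumes squidpy is installed; the sample ID reuses the one from the tests below, and the `color` key is illustrative, not taken from this diff):

    import squidpy as sq

    # Replaces sc.datasets.visium_sge(): download a processed Visium sample by ID.
    adata = sq.datasets.visium("V1_Human_Heart")

    # Replaces sc.read_visium(): read a local Space Ranger output directory instead.
    # adata = sq.read.visium("path/to/spaceranger/outs")

    # Replaces sc.pl.spatial(): scatter in spatial coordinates over the tissue image.
    sq.pl.spatial_scatter(adata, color="in_tissue")
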
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 4bad3800d7..5e0fc1e125 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -111,6 +111,7 @@ def test_pbmc68k_reduced():
     sc.datasets.pbmc68k_reduced()

+@pytest.mark.filterwarnings("ignore:Use `squidpy.*` instead:FutureWarning")
 @pytest.mark.internet
 def test_visium_datasets():
     """Tests that reading/downloading works and does not have global effects."""
@@ -121,6 +122,7 @@
     assert_adata_equal(hheart, hheart_again)

+@pytest.mark.filterwarnings("ignore:Use `squidpy.*` instead:FutureWarning")
 @pytest.mark.internet
 def test_visium_datasets_dir_change(tmp_path: Path):
     """Test that changing the dataset dir doesn't break reading."""
@@ -132,6 +134,7 @@
     assert_adata_equal(mbrain, mbrain_again)

+@pytest.mark.filterwarnings("ignore:Use `squidpy.*` instead:FutureWarning")
 @pytest.mark.internet
 def test_visium_datasets_images():
     """Test that image download works and does not have global effects."""
diff --git a/tests/test_embedding_plots.py b/tests/test_embedding_plots.py
deleted file mode 100644
index d48f44b2b6..0000000000
--- a/tests/test_embedding_plots.py
+++ /dev/null
@@ -1,566 +0,0 @@
-from __future__ import annotations
-
-from functools import partial
-from pathlib import Path
-from typing import TYPE_CHECKING
-
-import matplotlib as mpl
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-import pytest
-import seaborn as sns
-from matplotlib.colors import Normalize
-from matplotlib.testing.compare import compare_images
-
-import scanpy as sc
-from testing.scanpy._helpers.data import pbmc3k_processed
-
-if TYPE_CHECKING:
-    from scanpy.plotting._utils import _LegendLoc
-
-
-HERE: Path = Path(__file__).parent
-ROOT = HERE / "_images"
-
-MISSING_VALUES_ROOT = ROOT / "embedding-missing-values"
-
-
-def check_images(pth1, pth2, *, tol):
-    result = compare_images(pth1, pth2, tol=tol)
-    assert result is None, result
-
-
-@pytest.fixture(scope="module")
-def adata():
-    """A bit cute."""
-    from matplotlib.image import imread
-    from sklearn.cluster import DBSCAN
-    from sklearn.datasets import make_blobs
-
-    empty_pixel = np.array([1.0, 1.0, 1.0, 0]).reshape(1, 1, -1)
-    image = imread(HERE.parent / "docs/_static/img/Scanpy_Logo_RGB.png")
-    x, y = np.where(np.logical_and.reduce(~np.equal(image, empty_pixel), axis=2))
-
-    # Just using to calculate the hex coords
-    hexes = plt.hexbin(x, y, gridsize=(44, 100))
-    counts = hexes.get_array()
-    pixels = hexes.get_offsets()[counts != 0]
-    plt.close()
-
-    labels = DBSCAN(eps=20, min_samples=2).fit(pixels).labels_
-    order = np.argsort(labels)
-    adata = sc.AnnData(
-        make_blobs(
-            pd.Series(labels[order]).value_counts().values,
-            n_features=20,
-            shuffle=False,
-            random_state=42,
-        )[0],
-        obs={"label": pd.Categorical(labels[order].astype(str))},
-        obsm={"spatial": pixels[order, ::-1]},
-        uns={
-            "spatial": {
-                "scanpy_img": {
-                    "images": {"hires": image},
-                    "scalefactors": {
-                        "tissue_hires_scalef": 1,
-                        "spot_diameter_fullres": 10,
-                    },
-                }
-            }
-        },
-    )
-    sc.pp.pca(adata)
-
-    # Adding some missing values
-    adata.obs["label_missing"] = adata.obs["label"].copy()
-    adata.obs["label_missing"][::2] = np.nan
-
-    adata.obs["1_missing"] = adata.obs_vector("1")
-    adata.obs.loc[
-        adata.obsm["spatial"][:, 0] < adata.obsm["spatial"][:, 0].mean(), "1_missing"
-    ] = np.nan
-
-    return adata
-
-
-@pytest.fixture
-def fixture_request(request):
-    """Returns a Request object.
- - Allows you to access names of parameterized tests from within a test. - """ - return request - - -@pytest.fixture( - params=[(0, 0, 0, 1), None], - ids=["na_color.black_tup", "na_color.default"], -) -def na_color(request): - return request.param - - -@pytest.fixture(params=[True, False], ids=["na_in_legend.True", "na_in_legend.False"]) -def na_in_legend(request): - return request.param - - -@pytest.fixture( - params=[partial(sc.pl.pca, show=False), partial(sc.pl.spatial, show=False)], - ids=["pca", "spatial"], -) -def plotfunc(request): - return request.param - - -@pytest.fixture( - params=["on data", "right margin", "lower center", None], - ids=["legend.on_data", "legend.on_right", "legend.on_bottom", "legend.off"], -) -def legend_loc(request) -> _LegendLoc | None: - return request.param - - -@pytest.fixture( - params=[lambda x: list(x.cat.categories[:3]), lambda x: []], - ids=["groups.3", "groups.all"], -) -def groupsfunc(request): - return request.param - - -@pytest.fixture( - params=[ - pytest.param( - {"vmin": None, "vmax": None, "vcenter": None, "norm": None}, - id="vbounds.default", - ), - pytest.param( - {"vmin": 0, "vmax": 5, "vcenter": None, "norm": None}, id="vbounds.numbers" - ), - pytest.param( - {"vmin": "p15", "vmax": "p90", "vcenter": None, "norm": None}, - id="vbounds.percentile", - ), - pytest.param( - {"vmin": 0, "vmax": "p99", "vcenter": 0.1, "norm": None}, - id="vbounds.vcenter", - ), - pytest.param( - {"vmin": None, "vmax": None, "vcenter": None, "norm": Normalize(0, 5)}, - id="vbounds.norm", - ), - ] -) -def vbounds(request): - return request.param - - -def test_missing_values_categorical( - *, - fixture_request: pytest.FixtureRequest, - image_comparer, - adata, - plotfunc, - na_color, - na_in_legend, - legend_loc, - groupsfunc, -): - save_and_compare_images = partial(image_comparer, MISSING_VALUES_ROOT, tol=15) - - base_name = fixture_request.node.name - - # Passing through a dict so it's easier to use default values - kwargs = {} - kwargs["legend_loc"] = legend_loc - kwargs["groups"] = groupsfunc(adata.obs["label"]) - if na_color is not None: - kwargs["na_color"] = na_color - kwargs["na_in_legend"] = na_in_legend - - plotfunc(adata, color=["label", "label_missing"], **kwargs) - - save_and_compare_images(base_name) - - -def test_missing_values_continuous( - *, - fixture_request: pytest.FixtureRequest, - image_comparer, - adata, - plotfunc, - na_color, - vbounds, -): - save_and_compare_images = partial(image_comparer, MISSING_VALUES_ROOT, tol=15) - - base_name = fixture_request.node.name - - # Passing through a dict so it's easier to use default values - kwargs = {} - kwargs.update(vbounds) - if na_color is not None: - kwargs["na_color"] = na_color - - plotfunc(adata, color=["1", "1_missing"], **kwargs) - - save_and_compare_images(base_name) - - -def test_enumerated_palettes(fixture_request, adata, tmpdir, plotfunc): - tmpdir = Path(tmpdir) - base_name = fixture_request.node.name - - categories = adata.obs["label"].cat.categories - colors_rgb = dict(zip(categories, sns.color_palette(n_colors=12))) - - dict_pth = tmpdir / f"rgbdict_{base_name}.png" - list_pth = tmpdir / f"rgblist_{base_name}.png" - - # making a copy so colors aren't saved - plotfunc(adata.copy(), color="label", palette=colors_rgb) - plt.savefig(dict_pth, dpi=40) - plt.close() - plotfunc(adata.copy(), color="label", palette=[colors_rgb[c] for c in categories]) - plt.savefig(list_pth, dpi=40) - plt.close() - - check_images(dict_pth, list_pth, tol=15) - - -def test_dimension_broadcasting(adata, 
tmpdir, check_same_image): - tmpdir = Path(tmpdir) - - with pytest.raises( - ValueError, - match=r"Could not broadcast together arguments with shapes: \[2, 3, 1\]", - ): - sc.pl.pca( - adata, color=["label", "1_missing"], dimensions=[(0, 1), (1, 2), (2, 3)] - ) - - dims_pth = tmpdir / "broadcast_dims.png" - color_pth = tmpdir / "broadcast_colors.png" - - sc.pl.pca(adata, color=["label", "label", "label"], dimensions=(2, 3), show=False) - plt.savefig(dims_pth, dpi=40) - plt.close() - sc.pl.pca(adata, color="label", dimensions=[(2, 3), (2, 3), (2, 3)], show=False) - plt.savefig(color_pth, dpi=40) - plt.close() - - check_same_image(dims_pth, color_pth, tol=5) - - -def test_marker_broadcasting(adata, tmpdir, check_same_image): - tmpdir = Path(tmpdir) - - with pytest.raises( - ValueError, - match=r"Could not broadcast together arguments with shapes: \[2, 1, 3\]", - ): - sc.pl.pca(adata, color=["label", "1_missing"], marker=[".", "^", "x"]) - - dims_pth = tmpdir / "broadcast_markers.png" - color_pth = tmpdir / "broadcast_colors_for_markers.png" - - sc.pl.pca(adata, color=["label", "label", "label"], marker="^", show=False) - plt.savefig(dims_pth, dpi=40) - plt.close() - sc.pl.pca(adata, color="label", marker=["^", "^", "^"], show=False) - plt.savefig(color_pth, dpi=40) - plt.close() - - check_same_image(dims_pth, color_pth, tol=5) - - -def test_dimensions_same_as_components(adata, tmpdir, check_same_image): - tmpdir = Path(tmpdir) - adata = adata.copy() - adata.obs["mean"] = np.ravel(adata.X.mean(axis=1)) - - comp_pth = tmpdir / "components_plot.png" - dims_pth = tmpdir / "dimension_plot.png" - - # TODO: Deprecate components kwarg - # with pytest.warns(FutureWarning, match=r"components .* deprecated"): - sc.pl.pca( - adata, - color=["mean", "label"], - components=["1,2", "2,3"], - show=False, - ) - plt.savefig(comp_pth, dpi=40) - plt.close() - - sc.pl.pca( - adata, - color=["mean", "mean", "label", "label"], - dimensions=[(0, 1), (1, 2), (0, 1), (1, 2)], - show=False, - ) - plt.savefig(dims_pth, dpi=40) - plt.close() - - check_same_image(dims_pth, comp_pth, tol=5) - - -def test_embedding_colorbar_location(image_comparer): - save_and_compare_images = partial(image_comparer, ROOT, tol=15) - - adata = pbmc3k_processed().raw.to_adata() - - sc.pl.pca(adata, color="LDHB", colorbar_loc=None) - - save_and_compare_images("no_colorbar") - - -# Spatial specific - - -def test_visium_circles(image_comparer): # standard visium data - save_and_compare_images = partial(image_comparer, ROOT, tol=15) - - adata = sc.read_visium(HERE / "_data" / "visium_data" / "1.0.0") - adata.obs = adata.obs.astype({"array_row": "str"}) - - sc.pl.spatial( - adata, - color="array_row", - groups=["24", "33"], - crop_coord=(100, 400, 400, 100), - alpha=0.5, - size=1.3, - show=False, - ) - - save_and_compare_images("spatial_visium") - - -def test_visium_default(image_comparer): # default values - from packaging.version import parse as parse_version - - if parse_version(mpl.__version__) < parse_version("3.7.0"): - pytest.xfail("Matplotlib 3.7.0+ required for this test") - - save_and_compare_images = partial(image_comparer, ROOT, tol=5) - - adata = sc.read_visium(HERE / "_data" / "visium_data" / "1.0.0") - adata.obs = adata.obs.astype({"array_row": "str"}) - - # Points default to transparent if an image is included - sc.pl.spatial(adata, show=False) - - save_and_compare_images("spatial_visium_default") - - -def test_visium_empty_img_key(image_comparer): # visium coordinates but image empty - save_and_compare_images = 
partial(image_comparer, ROOT, tol=15) - - adata = sc.read_visium(HERE / "_data" / "visium_data" / "1.0.0") - adata.obs = adata.obs.astype({"array_row": "str"}) - - sc.pl.spatial(adata, img_key=None, color="array_row", show=False) - - save_and_compare_images("spatial_visium_empty_image") - - sc.pl.embedding(adata, basis="spatial", color="array_row", show=False) - save_and_compare_images("spatial_visium_embedding") - - -def test_spatial_general(image_comparer): # general coordinates - save_and_compare_images = partial(image_comparer, ROOT, tol=15) - - adata = sc.read_visium(HERE / "_data" / "visium_data" / "1.0.0") - adata.obs = adata.obs.astype({"array_row": "str"}) - spatial_metadata = adata.uns.pop( - "spatial" - ) # spatial data don't have imgs, so remove entry from uns - # Required argument for now - spot_size = list(spatial_metadata.values())[0]["scalefactors"][ - "spot_diameter_fullres" - ] - - sc.pl.spatial(adata, show=False, spot_size=spot_size) - save_and_compare_images("spatial_general_nocol") - - # category - sc.pl.spatial(adata, show=False, spot_size=spot_size, color="array_row") - save_and_compare_images("spatial_general_cat") - - # continuous - sc.pl.spatial(adata, show=False, spot_size=spot_size, color="array_col") - save_and_compare_images("spatial_general_cont") - - -def test_spatial_external_img(image_comparer): # external image - save_and_compare_images = partial(image_comparer, ROOT, tol=15) - - adata = sc.read_visium(HERE / "_data" / "visium_data" / "1.0.0") - adata.obs = adata.obs.astype({"array_row": "str"}) - - img = adata.uns["spatial"]["custom"]["images"]["hires"] - scalef = adata.uns["spatial"]["custom"]["scalefactors"]["tissue_hires_scalef"] - sc.pl.spatial( - adata, - color="array_row", - scale_factor=scalef, - img=img, - basis="spatial", - show=False, - ) - save_and_compare_images("spatial_external_img") - - -@pytest.fixture(scope="module") -def equivalent_spatial_plotters(adata): - no_spatial = adata.copy() - del no_spatial.uns["spatial"] - - img_key = "hires" - library_id = list(adata.uns["spatial"])[0] - spatial_data = adata.uns["spatial"][library_id] - img = spatial_data["images"][img_key] - scale_factor = spatial_data["scalefactors"][f"tissue_{img_key}_scalef"] - spot_size = spatial_data["scalefactors"]["spot_diameter_fullres"] - - orig_plotter = partial(sc.pl.spatial, adata, color="1", show=False) - removed_plotter = partial( - sc.pl.spatial, - no_spatial, - color="1", - img=img, - scale_factor=scale_factor, - spot_size=spot_size, - show=False, - ) - - return (orig_plotter, removed_plotter) - - -@pytest.fixture(scope="module") -def equivalent_spatial_plotters_no_img(equivalent_spatial_plotters): - orig, removed = equivalent_spatial_plotters - return (partial(orig, img_key=None), partial(removed, img=None, scale_factor=None)) - - -@pytest.fixture( - params=[ - pytest.param({"crop_coord": (50, 200, 0, 500)}, id="crop"), - pytest.param({"size": 0.5}, id="size:.5"), - pytest.param({"size": 2}, id="size:2"), - pytest.param({"spot_size": 5}, id="spotsize"), - pytest.param({"bw": True}, id="bw"), - # Shape of the image for particular fixture, should not be hardcoded like this - pytest.param({"img": np.ones((774, 1755, 4)), "scale_factor": 1.0}, id="img"), - pytest.param( - {"na_color": (0, 0, 0, 0), "color": "1_missing"}, id="na_color.transparent" - ), - pytest.param( - {"na_color": "lightgray", "color": "1_missing"}, id="na_color.lightgray" - ), - ] -) -def spatial_kwargs(request): - return request.param - - -def 
test_manual_equivalency(equivalent_spatial_plotters, tmpdir, spatial_kwargs): - """ - Tests that manually passing values to sc.pl.spatial is similar to automatic extraction. - """ - orig, removed = equivalent_spatial_plotters - - TESTDIR = Path(tmpdir) - orig_pth = TESTDIR / "orig.png" - removed_pth = TESTDIR / "removed.png" - - orig(**spatial_kwargs) - plt.savefig(orig_pth, dpi=40) - plt.close() - removed(**spatial_kwargs) - plt.savefig(removed_pth, dpi=40) - plt.close() - - check_images(orig_pth, removed_pth, tol=1) - - -def test_manual_equivalency_no_img( - equivalent_spatial_plotters_no_img, tmpdir, spatial_kwargs -): - if "bw" in spatial_kwargs: - # Has no meaning when there is no image - pytest.skip() - orig, removed = equivalent_spatial_plotters_no_img - - TESTDIR = Path(tmpdir) - orig_pth = TESTDIR / "orig.png" - removed_pth = TESTDIR / "removed.png" - - orig(**spatial_kwargs) - plt.savefig(orig_pth, dpi=40) - plt.close() - removed(**spatial_kwargs) - plt.savefig(removed_pth, dpi=40) - plt.close() - - check_images(orig_pth, removed_pth, tol=1) - - -def test_white_background_vs_no_img(adata, tmpdir, spatial_kwargs): - if {"bw", "img", "img_key", "na_color"}.intersection(spatial_kwargs): - # These arguments don't make sense for this check - pytest.skip() - - white_background = np.ones_like( - adata.uns["spatial"]["scanpy_img"]["images"]["hires"] - ) - TESTDIR = Path(tmpdir) - white_pth = TESTDIR / "white_background.png" - noimg_pth = TESTDIR / "no_img.png" - - sc.pl.spatial( - adata, - color="2", - img=white_background, - scale_factor=1.0, - show=False, - **spatial_kwargs, - ) - plt.savefig(white_pth) - sc.pl.spatial(adata, color="2", img_key=None, show=False, **spatial_kwargs) - plt.savefig(noimg_pth) - - check_images(white_pth, noimg_pth, tol=1) - - -def test_spatial_na_color(adata, tmpdir): - """ - Check that na_color defaults to transparent when an image is present, light gray when not. 
- """ - white_background = np.ones_like( - adata.uns["spatial"]["scanpy_img"]["images"]["hires"] - ) - TESTDIR = Path(tmpdir) - lightgray_pth = TESTDIR / "lightgray.png" - transparent_pth = TESTDIR / "transparent.png" - noimg_pth = TESTDIR / "noimg.png" - whiteimg_pth = TESTDIR / "whiteimg.png" - - def plot(pth, **kwargs): - sc.pl.spatial(adata, color="1_missing", show=False, **kwargs) - plt.savefig(pth, dpi=40) - plt.close() - - plot(lightgray_pth, na_color="lightgray", img_key=None) - plot(transparent_pth, na_color=(0.0, 0.0, 0.0, 0.0), img_key=None) - plot(noimg_pth, img_key=None) - plot(whiteimg_pth, img=white_background, scale_factor=1.0) - - check_images(lightgray_pth, noimg_pth, tol=1) - check_images(transparent_pth, whiteimg_pth, tol=1) - with pytest.raises(AssertionError): - check_images(lightgray_pth, transparent_pth, tol=1) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 2f0f5f60cd..f135a68aa4 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -1456,11 +1456,10 @@ def test_rankings(image_comparer): # TODO: Make more generic -def test_scatter_rep(tmpdir): +def test_scatter_rep(tmp_path): """ Test to make sure I can predict when scatter reps should be the same """ - TESTDIR = Path(tmpdir) rep_args = { "raw": {"use_raw": True}, "layer": {"layer": "layer", "use_raw": False}, @@ -1475,7 +1474,7 @@ def test_scatter_rep(tmpdir): columns=["rep", "gene", "result"], ) states["outpth"] = [ - TESTDIR / f"{state.gene}_{state.rep}_{state.result}.png" + tmp_path / f"{state.gene}_{state.rep}_{state.result}.png" for state in states.itertuples() ] pattern = np.array(list(chain.from_iterable(repeat(i, 5) for i in range(3)))) diff --git a/tests/test_plotting_embedded/conftest.py b/tests/test_plotting_embedded/conftest.py new file mode 100644 index 0000000000..d9e8ff8581 --- /dev/null +++ b/tests/test_plotting_embedded/conftest.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pytest + +import scanpy as sc + +HERE: Path = Path(__file__).parent + + +@pytest.fixture(scope="module") +def adata(): + """A bit cute.""" + from matplotlib.image import imread + from sklearn.cluster import DBSCAN + from sklearn.datasets import make_blobs + + empty_pixel = np.array([1.0, 1.0, 1.0, 0]).reshape(1, 1, -1) + image = imread(HERE.parent.parent / "docs/_static/img/Scanpy_Logo_RGB.png") + x, y = np.where(np.logical_and.reduce(~np.equal(image, empty_pixel), axis=2)) + + # Just using to calculate the hex coords + hexes = plt.hexbin(x, y, gridsize=(44, 100)) + counts = hexes.get_array() + pixels = hexes.get_offsets()[counts != 0] + plt.close() + + labels = DBSCAN(eps=20, min_samples=2).fit(pixels).labels_ + order = np.argsort(labels) + adata = sc.AnnData( + make_blobs( + pd.Series(labels[order]).value_counts().values, + n_features=20, + shuffle=False, + random_state=42, + )[0], + obs={"label": pd.Categorical(labels[order].astype(str))}, + obsm={"spatial": pixels[order, ::-1]}, + uns={ + "spatial": { + "scanpy_img": { + "images": {"hires": image}, + "scalefactors": { + "tissue_hires_scalef": 1, + "spot_diameter_fullres": 10, + }, + } + } + }, + ) + sc.pp.pca(adata) + + # Adding some missing values + adata.obs["label_missing"] = adata.obs["label"].copy() + adata.obs.loc[::2, "label_missing"] = np.nan + + adata.obs["1_missing"] = adata.obs_vector("1") + adata.obs.loc[ + adata.obsm["spatial"][:, 0] < adata.obsm["spatial"][:, 0].mean(), "1_missing" + ] = np.nan + + return 
adata diff --git a/tests/test_plotting_embedded/test_embeddings.py b/tests/test_plotting_embedded/test_embeddings.py new file mode 100644 index 0000000000..c5dc8d3e53 --- /dev/null +++ b/tests/test_plotting_embedded/test_embeddings.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +from functools import partial, wraps +from pathlib import Path +from typing import TYPE_CHECKING + +import matplotlib.pyplot as plt +import numpy as np +import pytest +import seaborn as sns +from matplotlib.colors import Normalize +from matplotlib.testing.compare import compare_images + +import scanpy as sc +from testing.scanpy._helpers.data import pbmc3k_processed + +if TYPE_CHECKING: + from scanpy.plotting._utils import _LegendLoc + + +HERE: Path = Path(__file__).parent +ROOT = HERE.parent / "_images" + +MISSING_VALUES_ROOT = ROOT / "embedding-missing-values" + + +def check_images(pth1: Path, pth2: Path, *, tol: int) -> None: + result = compare_images(str(pth1), str(pth2), tol=tol) + assert result is None, result + + +@pytest.fixture( + params=[(0, 0, 0, 1), None], + ids=["na_color.black_tup", "na_color.default"], +) +def na_color(request): + return request.param + + +@pytest.fixture(params=[True, False], ids=["na_in_legend.True", "na_in_legend.False"]) +def na_in_legend(request): + return request.param + + +@pytest.fixture(params=[sc.pl.pca, sc.pl.spatial]) +def plotfunc(request): + if request.param is sc.pl.spatial: + + @wraps(request.param) + def f(adata, **kwargs): + with pytest.warns(FutureWarning, match=r"Use `squidpy.*` instead"): + return sc.pl.spatial(adata, **kwargs) + + else: + f = request.param + return partial(f, show=False) + + +@pytest.fixture( + params=["on data", "right margin", "lower center", None], + ids=["legend.on_data", "legend.on_right", "legend.on_bottom", "legend.off"], +) +def legend_loc(request) -> _LegendLoc | None: + return request.param + + +@pytest.fixture( + params=[lambda x: list(x.cat.categories[:3]), lambda x: []], + ids=["groups.3", "groups.all"], +) +def groupsfunc(request): + return request.param + + +@pytest.fixture( + params=[ + pytest.param( + {"vmin": None, "vmax": None, "vcenter": None, "norm": None}, + id="vbounds.default", + ), + pytest.param( + {"vmin": 0, "vmax": 5, "vcenter": None, "norm": None}, id="vbounds.numbers" + ), + pytest.param( + {"vmin": "p15", "vmax": "p90", "vcenter": None, "norm": None}, + id="vbounds.percentile", + ), + pytest.param( + {"vmin": 0, "vmax": "p99", "vcenter": 0.1, "norm": None}, + id="vbounds.vcenter", + ), + pytest.param( + {"vmin": None, "vmax": None, "vcenter": None, "norm": Normalize(0, 5)}, + id="vbounds.norm", + ), + ] +) +def vbounds(request): + return request.param + + +def test_missing_values_categorical( + *, + request: pytest.FixtureRequest, + image_comparer, + adata, + plotfunc, + na_color, + na_in_legend, + legend_loc, + groupsfunc, +): + save_and_compare_images = partial(image_comparer, MISSING_VALUES_ROOT, tol=15) + + base_name = request.node.name + + # Passing through a dict so it's easier to use default values + kwargs = {} + kwargs["legend_loc"] = legend_loc + kwargs["groups"] = groupsfunc(adata.obs["label"]) + if na_color is not None: + kwargs["na_color"] = na_color + kwargs["na_in_legend"] = na_in_legend + + plotfunc(adata, color=["label", "label_missing"], **kwargs) + + save_and_compare_images(base_name) + + +def test_missing_values_continuous( + *, + request: pytest.FixtureRequest, + image_comparer, + adata, + plotfunc, + na_color, + vbounds, +): + save_and_compare_images = partial(image_comparer, 
MISSING_VALUES_ROOT, tol=15) + + base_name = request.node.name + + # Passing through a dict so it's easier to use default values + kwargs = {} + kwargs.update(vbounds) + if na_color is not None: + kwargs["na_color"] = na_color + + plotfunc(adata, color=["1", "1_missing"], **kwargs) + + save_and_compare_images(base_name) + + +def test_enumerated_palettes(request, adata, tmp_path, plotfunc): + base_name = request.node.name + + categories = adata.obs["label"].cat.categories + colors_rgb = dict(zip(categories, sns.color_palette(n_colors=12))) + + dict_pth = tmp_path / f"rgbdict_{base_name}.png" + list_pth = tmp_path / f"rgblist_{base_name}.png" + + # making a copy so colors aren't saved + plotfunc(adata.copy(), color="label", palette=colors_rgb) + plt.savefig(dict_pth, dpi=40) + plt.close() + plotfunc(adata.copy(), color="label", palette=[colors_rgb[c] for c in categories]) + plt.savefig(list_pth, dpi=40) + plt.close() + + check_images(dict_pth, list_pth, tol=15) + + +def test_dimension_broadcasting(adata, tmp_path, check_same_image): + with pytest.raises( + ValueError, + match=r"Could not broadcast together arguments with shapes: \[2, 3, 1\]", + ): + sc.pl.pca( + adata, color=["label", "1_missing"], dimensions=[(0, 1), (1, 2), (2, 3)] + ) + + dims_pth = tmp_path / "broadcast_dims.png" + color_pth = tmp_path / "broadcast_colors.png" + + sc.pl.pca(adata, color=["label", "label", "label"], dimensions=(2, 3), show=False) + plt.savefig(dims_pth, dpi=40) + plt.close() + sc.pl.pca(adata, color="label", dimensions=[(2, 3), (2, 3), (2, 3)], show=False) + plt.savefig(color_pth, dpi=40) + plt.close() + + check_same_image(dims_pth, color_pth, tol=5) + + +def test_marker_broadcasting(adata, tmp_path, check_same_image): + with pytest.raises( + ValueError, + match=r"Could not broadcast together arguments with shapes: \[2, 1, 3\]", + ): + sc.pl.pca(adata, color=["label", "1_missing"], marker=[".", "^", "x"]) + + dims_pth = tmp_path / "broadcast_markers.png" + color_pth = tmp_path / "broadcast_colors_for_markers.png" + + sc.pl.pca(adata, color=["label", "label", "label"], marker="^", show=False) + plt.savefig(dims_pth, dpi=40) + plt.close() + sc.pl.pca(adata, color="label", marker=["^", "^", "^"], show=False) + plt.savefig(color_pth, dpi=40) + plt.close() + + check_same_image(dims_pth, color_pth, tol=5) + + +def test_dimensions_same_as_components(adata, tmp_path, check_same_image): + adata = adata.copy() + adata.obs["mean"] = np.ravel(adata.X.mean(axis=1)) + + comp_pth = tmp_path / "components_plot.png" + dims_pth = tmp_path / "dimension_plot.png" + + # TODO: Deprecate components kwarg + # with pytest.warns(FutureWarning, match=r"components .* deprecated"): + sc.pl.pca( + adata, + color=["mean", "label"], + components=["1,2", "2,3"], + show=False, + ) + plt.savefig(comp_pth, dpi=40) + plt.close() + + sc.pl.pca( + adata, + color=["mean", "mean", "label", "label"], + dimensions=[(0, 1), (1, 2), (0, 1), (1, 2)], + show=False, + ) + plt.savefig(dims_pth, dpi=40) + plt.close() + + check_same_image(dims_pth, comp_pth, tol=5) + + +def test_embedding_colorbar_location(image_comparer): + save_and_compare_images = partial(image_comparer, ROOT, tol=15) + + adata = pbmc3k_processed().raw.to_adata() + + sc.pl.pca(adata, color="LDHB", colorbar_loc=None) + + save_and_compare_images("no_colorbar") diff --git a/tests/test_plotting_embedded/test_spatial.py b/tests/test_plotting_embedded/test_spatial.py new file mode 100644 index 0000000000..873db68794 --- /dev/null +++ b/tests/test_plotting_embedded/test_spatial.py @@ -0,0 
+1,267 @@ +from __future__ import annotations + +from functools import partial +from pathlib import Path + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import pytest +from matplotlib.testing.compare import compare_images + +import scanpy as sc + +HERE: Path = Path(__file__).parent +ROOT = HERE.parent / "_images" +DATA_DIR = HERE.parent / "_data" + + +pytestmark = [ + pytest.mark.filterwarnings("ignore:Use `squidpy.*` instead:FutureWarning") +] + + +def check_images(pth1: Path, pth2: Path, *, tol: int) -> None: + result = compare_images(str(pth1), str(pth2), tol=tol) + assert result is None, result + + +def test_visium_circles(image_comparer): # standard visium data + save_and_compare_images = partial(image_comparer, ROOT, tol=15) + + adata = sc.read_visium(DATA_DIR / "visium_data" / "1.0.0") + adata.obs = adata.obs.astype({"array_row": "str"}) + + sc.pl.spatial( + adata, + color="array_row", + groups=["24", "33"], + crop_coord=(100, 400, 400, 100), + alpha=0.5, + size=1.3, + show=False, + ) + + save_and_compare_images("spatial_visium") + + +def test_visium_default(image_comparer): # default values + from packaging.version import parse as parse_version + + if parse_version(mpl.__version__) < parse_version("3.7.0"): + pytest.xfail("Matplotlib 3.7.0+ required for this test") + + save_and_compare_images = partial(image_comparer, ROOT, tol=5) + + adata = sc.read_visium(DATA_DIR / "visium_data" / "1.0.0") + adata.obs = adata.obs.astype({"array_row": "str"}) + + # Points default to transparent if an image is included + sc.pl.spatial(adata, show=False) + + save_and_compare_images("spatial_visium_default") + + +def test_visium_empty_img_key(image_comparer): # visium coordinates but image empty + save_and_compare_images = partial(image_comparer, ROOT, tol=15) + + adata = sc.read_visium(DATA_DIR / "visium_data" / "1.0.0") + adata.obs = adata.obs.astype({"array_row": "str"}) + + sc.pl.spatial(adata, img_key=None, color="array_row", show=False) + + save_and_compare_images("spatial_visium_empty_image") + + sc.pl.embedding(adata, basis="spatial", color="array_row", show=False) + save_and_compare_images("spatial_visium_embedding") + + +def test_spatial_general(image_comparer): # general coordinates + save_and_compare_images = partial(image_comparer, ROOT, tol=15) + + adata = sc.read_visium(DATA_DIR / "visium_data" / "1.0.0") + adata.obs = adata.obs.astype({"array_row": "str"}) + spatial_metadata = adata.uns.pop( + "spatial" + ) # spatial data don't have imgs, so remove entry from uns + # Required argument for now + spot_size = list(spatial_metadata.values())[0]["scalefactors"][ + "spot_diameter_fullres" + ] + + sc.pl.spatial(adata, show=False, spot_size=spot_size) + save_and_compare_images("spatial_general_nocol") + + # category + sc.pl.spatial(adata, show=False, spot_size=spot_size, color="array_row") + save_and_compare_images("spatial_general_cat") + + # continuous + sc.pl.spatial(adata, show=False, spot_size=spot_size, color="array_col") + save_and_compare_images("spatial_general_cont") + + +def test_spatial_external_img(image_comparer): # external image + save_and_compare_images = partial(image_comparer, ROOT, tol=15) + + adata = sc.read_visium(DATA_DIR / "visium_data" / "1.0.0") + adata.obs = adata.obs.astype({"array_row": "str"}) + + img = adata.uns["spatial"]["custom"]["images"]["hires"] + scalef = adata.uns["spatial"]["custom"]["scalefactors"]["tissue_hires_scalef"] + sc.pl.spatial( + adata, + color="array_row", + scale_factor=scalef, + img=img, + 
basis="spatial", + show=False, + ) + save_and_compare_images("spatial_external_img") + + +@pytest.fixture(scope="module") +def equivalent_spatial_plotters(adata): + no_spatial = adata.copy() + del no_spatial.uns["spatial"] + + img_key = "hires" + library_id = list(adata.uns["spatial"])[0] + spatial_data = adata.uns["spatial"][library_id] + img = spatial_data["images"][img_key] + scale_factor = spatial_data["scalefactors"][f"tissue_{img_key}_scalef"] + spot_size = spatial_data["scalefactors"]["spot_diameter_fullres"] + + orig_plotter = partial(sc.pl.spatial, adata, color="1", show=False) + removed_plotter = partial( + sc.pl.spatial, + no_spatial, + color="1", + img=img, + scale_factor=scale_factor, + spot_size=spot_size, + show=False, + ) + + return (orig_plotter, removed_plotter) + + +@pytest.fixture(scope="module") +def equivalent_spatial_plotters_no_img(equivalent_spatial_plotters): + orig, removed = equivalent_spatial_plotters + return (partial(orig, img_key=None), partial(removed, img=None, scale_factor=None)) + + +@pytest.fixture( + params=[ + pytest.param({"crop_coord": (50, 200, 0, 500)}, id="crop"), + pytest.param({"size": 0.5}, id="size:.5"), + pytest.param({"size": 2}, id="size:2"), + pytest.param({"spot_size": 5}, id="spotsize"), + pytest.param({"bw": True}, id="bw"), + # Shape of the image for particular fixture, should not be hardcoded like this + pytest.param({"img": np.ones((774, 1755, 4)), "scale_factor": 1.0}, id="img"), + pytest.param( + {"na_color": (0, 0, 0, 0), "color": "1_missing"}, id="na_color.transparent" + ), + pytest.param( + {"na_color": "lightgray", "color": "1_missing"}, id="na_color.lightgray" + ), + ] +) +def spatial_kwargs(request): + return request.param + + +def test_manual_equivalency(equivalent_spatial_plotters, tmp_path, spatial_kwargs): + """ + Tests that manually passing values to sc.pl.spatial is similar to automatic extraction. 
+ """ + orig, removed = equivalent_spatial_plotters + + orig_pth = tmp_path / "orig.png" + removed_pth = tmp_path / "removed.png" + + orig(**spatial_kwargs) + plt.savefig(orig_pth, dpi=40) + plt.close() + removed(**spatial_kwargs) + plt.savefig(removed_pth, dpi=40) + plt.close() + + check_images(orig_pth, removed_pth, tol=1) + + +def test_manual_equivalency_no_img( + equivalent_spatial_plotters_no_img, tmp_path, spatial_kwargs +): + if "bw" in spatial_kwargs: + # Has no meaning when there is no image + pytest.skip() + orig, removed = equivalent_spatial_plotters_no_img + + orig_pth = tmp_path / "orig.png" + removed_pth = tmp_path / "removed.png" + + orig(**spatial_kwargs) + plt.savefig(orig_pth, dpi=40) + plt.close() + removed(**spatial_kwargs) + plt.savefig(removed_pth, dpi=40) + plt.close() + + check_images(orig_pth, removed_pth, tol=1) + + +def test_white_background_vs_no_img(adata, tmp_path, spatial_kwargs): + if {"bw", "img", "img_key", "na_color"}.intersection(spatial_kwargs): + # These arguments don't make sense for this check + pytest.skip() + + white_background = np.ones_like( + adata.uns["spatial"]["scanpy_img"]["images"]["hires"] + ) + white_pth = tmp_path / "white_background.png" + noimg_pth = tmp_path / "no_img.png" + + sc.pl.spatial( + adata, + color="2", + img=white_background, + scale_factor=1.0, + show=False, + **spatial_kwargs, + ) + plt.savefig(white_pth) + sc.pl.spatial(adata, color="2", img_key=None, show=False, **spatial_kwargs) + plt.savefig(noimg_pth) + + check_images(white_pth, noimg_pth, tol=1) + + +def test_spatial_na_color(adata, tmp_path): + """ + Check that na_color defaults to transparent when an image is present, light gray when not. + """ + white_background = np.ones_like( + adata.uns["spatial"]["scanpy_img"]["images"]["hires"] + ) + lightgray_pth = tmp_path / "lightgray.png" + transparent_pth = tmp_path / "transparent.png" + noimg_pth = tmp_path / "noimg.png" + whiteimg_pth = tmp_path / "whiteimg.png" + + def plot(pth, **kwargs): + sc.pl.spatial(adata, color="1_missing", show=False, **kwargs) + plt.savefig(pth, dpi=40) + plt.close() + + plot(lightgray_pth, na_color="lightgray", img_key=None) + plot(transparent_pth, na_color=(0.0, 0.0, 0.0, 0.0), img_key=None) + plot(noimg_pth, img_key=None) + plot(whiteimg_pth, img=white_background, scale_factor=1.0) + + check_images(lightgray_pth, noimg_pth, tol=1) + check_images(transparent_pth, whiteimg_pth, tol=1) + with pytest.raises(AssertionError): + check_images(lightgray_pth, transparent_pth, tol=1) diff --git a/tests/test_read_10x.py b/tests/test_read_10x.py index 7b31f6bddf..301a156bec 100644 --- a/tests/test_read_10x.py +++ b/tests/test_read_10x.py @@ -143,6 +143,7 @@ def visium_pth(request, tmp_path) -> Path: pytest.fail("add branch for new visium version") +@pytest.mark.filterwarnings("ignore:Use `squidpy.*` instead:FutureWarning") def test_read_visium_counts(visium_pth): """Test checking that read_visium reads the right genome""" spec_genome_v3 = sc.read_visium(visium_pth, genome="GRCh38") From 397d7036ed4fa8358ac552f7d8dc7b3c5ea5e93b Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 20 Dec 2024 14:19:20 +0100
Subject: [PATCH 06/24] Add sample probabilities (#3410)

---
 docs/release-notes/3410.feature.md | 1 +
 src/scanpy/get/_aggregated.py | 3 +-
 src/scanpy/get/get.py | 49 +++++++++++++-----
 src/scanpy/plotting/_tools/scatterplots.py | 3 +-
 src/scanpy/preprocessing/_scale.py | 2 +-
 src/scanpy/preprocessing/_simple.py | 16 +++++-
 src/scanpy/tools/_rank_genes_groups.py | 3 +-
 tests/test_preprocessing.py | 59 +++++++++++++++++-----
 8 files changed, 102 insertions(+), 34 deletions(-)
 create mode 100644 docs/release-notes/3410.feature.md

diff --git a/docs/release-notes/3410.feature.md b/docs/release-notes/3410.feature.md
new file mode 100644
index 0000000000..d95ad201ba
--- /dev/null
+++ b/docs/release-notes/3410.feature.md
@@ -0,0 +1 @@
+Add sampling probabilities/mask parameter `p` to {func}`~scanpy.pp.sample` {smaller}`P Angerer`
diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py
index 13ca54b5c4..53a18bb47c 100644
--- a/src/scanpy/get/_aggregated.py
+++ b/src/scanpy/get/_aggregated.py
@@ -263,8 +263,7 @@ def aggregate(
     if axis is None:
         axis = 1 if varm else 0
     axis, axis_name = _resolve_axis(axis)
-    if mask is not None:
-        mask = _check_mask(adata, mask, axis_name)
+    mask = _check_mask(adata, mask, axis_name)
     data = adata.X
     if sum(p is not None for p in [varm, obsm, layer]) > 1:
         raise TypeError("Please only provide one (or none) of varm, obsm, or layer")
diff --git a/src/scanpy/get/get.py b/src/scanpy/get/get.py
index f3172ed45e..c36ddde8f8 100644
--- a/src/scanpy/get/get.py
+++ b/src/scanpy/get/get.py
@@ -2,11 +2,12 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, TypeVar
 
 import numpy as np
 import pandas as pd
 from anndata import AnnData
+from numpy.typing import NDArray
 from packaging.version import Version
 from scipy.sparse import spmatrix
 
@@ -16,7 +17,11 @@
     from anndata._core.sparse_dataset import BaseCompressedSparseDataset
     from anndata._core.views import ArrayView
-    from numpy.typing import NDArray
+    from scipy.sparse import csc_matrix, csr_matrix
+
+    from .._compat import DaskArray
+
+    CSMatrix = csr_matrix | csc_matrix
 
 
 # --------------------------------------------------------------------------------
 # Plotting data helpers
@@ -485,11 +490,16 @@ def _set_obs_rep(
         raise AssertionError(msg)
 
 
+M = TypeVar("M", bound=NDArray[np.bool_] | NDArray[np.floating] | pd.Series | None)
+
+
 def _check_mask(
-    data: AnnData | np.ndarray,
-    mask: NDArray[np.bool_] | str,
+    data: AnnData | np.ndarray | CSMatrix | DaskArray,
+    mask: str | M,
     dim: Literal["obs", "var"],
-) -> NDArray[np.bool_]:  # Could also be a series, but should be one or the other
+    *,
+    allow_probabilities: bool = False,
+) -> M:  # Could also be a series, but should be one or the other
     """
     Validate mask argument
     Params
     ------
     data
         Annotated data matrix or numpy array.
     mask
-        The mask. Either an appropriatley sized boolean array, or name of a column which will be used to mask.
+        Mask (or probabilities if `allow_probabilities=True`).
+        Either an appropriately sized array, or name of a column.
     dim
         The dimension being masked.
+    allow_probabilities
+        Whether to allow probabilities as `mask`
     """
+    if mask is None:
+        return mask
+    desc = "mask/probabilities" if allow_probabilities else "mask"
+
     if isinstance(mask, str):
         if not isinstance(data, AnnData):
-            msg = "Cannot refer to mask with string without providing anndata object as argument"
+            msg = f"Cannot refer to {desc} with string without providing anndata object as argument"
             raise ValueError(msg)
 
         annot: pd.DataFrame = getattr(data, dim)
         if mask not in annot.columns:
             msg = (
                 f"Did not find `adata.{dim}[{mask!r}]`. "
-                f"Either add the mask first to `adata.{dim}`"
-                "or consider using the mask argument with a boolean array."
+                f"Either add the {desc} first to `adata.{dim}`"
+                f"or consider using the {desc} argument with an array."
             )
             raise ValueError(msg)
         mask_array = annot[mask].to_numpy()
     else:
         if len(mask) != data.shape[0 if dim == "obs" else 1]:
-            raise ValueError("The shape of the mask do not match the data.")
+            msg = f"The shape of the {desc} does not match the data."
+            raise ValueError(msg)
         mask_array = mask
 
-    if not pd.api.types.is_bool_dtype(mask_array.dtype):
-        raise ValueError("Mask array must be boolean.")
+    is_bool = pd.api.types.is_bool_dtype(mask_array.dtype)
+    if not allow_probabilities and not is_bool:
+        msg = "Mask array must be boolean."
+        raise ValueError(msg)
+    elif allow_probabilities and not (
+        is_bool or pd.api.types.is_float_dtype(mask_array.dtype)
+    ):
+        msg = f"{desc} array must be boolean or floating point."
+        raise ValueError(msg)
 
     return mask_array
diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py
index e2564eb17f..b54897678f 100644
--- a/src/scanpy/plotting/_tools/scatterplots.py
+++ b/src/scanpy/plotting/_tools/scatterplots.py
@@ -150,8 +150,7 @@ def embedding(
     # Checking the mask format and if used together with groups
     if groups is not None and mask_obs is not None:
         raise ValueError("Groups and mask arguments are incompatible.")
-    if mask_obs is not None:
-        mask_obs = _check_mask(adata, mask_obs, "obs")
+    mask_obs = _check_mask(adata, mask_obs, "obs")
 
     # Figure out if we're using raw
     if use_raw is None:
diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py
index d7123d5f65..bac08f246b 100644
--- a/src/scanpy/preprocessing/_scale.py
+++ b/src/scanpy/preprocessing/_scale.py
@@ -164,8 +164,8 @@ def scale_array(
 ):
     if copy:
         X = X.copy()
+    mask_obs = _check_mask(X, mask_obs, "obs")
     if mask_obs is not None:
-        mask_obs = _check_mask(X, mask_obs, "obs")
         scale_rv = scale_array(
             X[mask_obs, :],
             zero_center=zero_center,
diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py
index 29c267c3f4..821615676a 100644
--- a/src/scanpy/preprocessing/_simple.py
+++ b/src/scanpy/preprocessing/_simple.py
@@ -30,7 +30,7 @@
     sanitize_anndata,
     view_to_actual,
 )
-from ..get import _get_obs_rep, _set_obs_rep
+from ..get import _check_mask, _get_obs_rep, _set_obs_rep
 from ._distributed import materialize_as_ndarray
 from ._utils import _to_dense
 
@@ -838,6 +838,7 @@ def sample(
     copy: Literal[False] = False,
     replace: bool = False,
     axis: Literal["obs", 0, "var", 1] = "obs",
+    p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None,
 ) -> None: ...
 @overload
 def sample(
@@ -849,6 +850,7 @@ def sample(
     copy: Literal[True],
     replace: bool = False,
     axis: Literal["obs", 0, "var", 1] = "obs",
+    p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None,
 ) -> AnnData: ...
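A minimal usage sketch for the new `p` parameter (assuming a scanpy build that includes this patch; the toy `AnnData` mirrors the one used in the tests below):

```python
import numpy as np
import scanpy as sc
from anndata import AnnData

adata = AnnData(np.ones((200, 10)))

# Boolean mask: only even-indexed observations may be drawn;
# `sample` rescales a boolean `p` to uniform probabilities internally.
mask = np.arange(adata.n_obs) % 2 == 0
sub = sc.pp.sample(adata, n=40, p=mask, copy=True)

# Explicit probabilities: one non-negative value per observation, summing to 1.
probs = np.random.default_rng(0).random(adata.n_obs)
probs /= probs.sum()
sub = sc.pp.sample(adata, fraction=0.1, p=probs, copy=True)
```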
@overload def sample( @@ -860,6 +862,7 @@ def sample( copy: bool = False, replace: bool = False, axis: Literal["obs", 0, "var", 1] = "obs", + p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None, ) -> tuple[A, NDArray[np.int64]]: ... def sample( data: AnnData | np.ndarray | CSMatrix | DaskArray, @@ -870,6 +873,7 @@ def sample( copy: bool = False, replace: bool = False, axis: Literal["obs", 0, "var", 1] = "obs", + p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None, ) -> AnnData | None | tuple[np.ndarray | CSMatrix | DaskArray, NDArray[np.int64]]: """\ Sample observations or variables with or without replacement. @@ -881,6 +885,7 @@ def sample( Rows correspond to cells and columns to genes. fraction Sample to this `fraction` of the number of observations or variables. + (All of them, even if there are `0`s/`False`s in `p`.) This can be larger than 1.0, if `replace=True`. See `axis` and `replace`. n @@ -894,6 +899,10 @@ def sample( If True, samples are drawn with replacement. axis Sample `obs`\\ ervations (axis 0) or `var`\\ iables (axis 1). + p + Drawing probabilities (floats) or mask (bools). + Either an `axis`-sized array, or the name of a column. + If `p` is an array of probabilities, it must sum to 1. Returns ------- @@ -910,6 +919,9 @@ def sample( msg = "Inplace sampling (`copy=False`) is not implemented for backed objects." raise NotImplementedError(msg) axis, axis_name = _resolve_axis(axis) + p = _check_mask(data, p, dim=axis_name, allow_probabilities=True) + if p is not None and p.dtype == bool: + p = p.astype(np.float64) / p.sum() old_n = data.shape[axis] match (fraction, n): case (None, None): @@ -933,7 +945,7 @@ def sample( # actually do subsampling rng = np.random.default_rng(rng) - indices = rng.choice(old_n, size=n, replace=replace) + indices = rng.choice(old_n, size=n, replace=replace, p=p) # overload 1: inplace AnnData subset if not copy and isinstance(data, AnnData): diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index aa4428dad1..2c214fcfdd 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -594,8 +594,7 @@ def rank_genes_groups( >>> # to visualize the results >>> sc.pl.rank_genes_groups(adata) """ - if mask_var is not None: - mask_var = _check_mask(adata, mask_var, "var") + mask_var = _check_mask(adata, mask_var, "var") if use_raw is None: use_raw = adata.raw is not None diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 36283e7ed0..6282c5ccf4 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -29,6 +29,8 @@ from collections.abc import Callable from typing import Any, Literal + from numpy.typing import NDArray + CSMatrix = sp.csc_matrix | sp.csr_matrix @@ -144,31 +146,55 @@ def test_normalize_per_cell(): assert adata.X.sum(axis=1).tolist() == adata_sparse.X.sum(axis=1).A1.tolist() +def _random_probs(n: int, frac_zero: float) -> NDArray[np.float64]: + """ + Generate a random probability distribution of `n` values between 0 and 1. 
+ """ + probs = np.random.randint(0, 10000, n).astype(np.float64) + probs[probs < np.quantile(probs, frac_zero)] = 0 + probs /= probs.sum() + np.testing.assert_almost_equal(probs.sum(), 1) + return probs + + @pytest.mark.parametrize("array_type", ARRAY_TYPES) @pytest.mark.parametrize("which", ["copy", "inplace", "array"]) @pytest.mark.parametrize( - ("axis", "fraction", "n", "replace", "expected"), + ("axis", "f_or_n", "replace"), + [ + pytest.param(0, 40, False, id="obs-40-no_replace"), + pytest.param(0, 0.1, False, id="obs-0.1-no_replace"), + pytest.param(0, 201, True, id="obs-201-replace"), + pytest.param(0, 1, True, id="obs-1-replace"), + pytest.param(1, 10, False, id="var-10-no_replace"), + pytest.param(1, 11, True, id="var-11-replace"), + pytest.param(1, 2.0, True, id="var-2.0-replace"), + ], +) +@pytest.mark.parametrize( + "ps", [ - pytest.param(0, None, 40, False, 40, id="obs-40-no_replace"), - pytest.param(0, 0.1, None, False, 20, id="obs-0.1-no_replace"), - pytest.param(0, None, 201, True, 201, id="obs-201-replace"), - pytest.param(0, None, 1, True, 1, id="obs-1-replace"), - pytest.param(1, None, 10, False, 10, id="var-10-no_replace"), - pytest.param(1, None, 11, True, 11, id="var-11-replace"), - pytest.param(1, 2.0, None, True, 20, id="var-2.0-replace"), + dict(obs=None, var=None), + dict(obs=np.tile([True, False], 100), var=np.tile([True, False], 5)), + dict(obs=_random_probs(200, 0.3), var=_random_probs(10, 0.7)), ], + ids=["all", "mask", "p"], ) def test_sample( *, + request: pytest.FixtureRequest, array_type: Callable[[np.ndarray], np.ndarray | CSMatrix], which: Literal["copy", "inplace", "array"], axis: Literal[0, 1], - fraction: float | None, - n: int | None, + f_or_n: float | int, # noqa: PYI041 replace: bool, - expected: int, + ps: dict[Literal["obs", "var"], NDArray[np.bool_] | None], ): adata = AnnData(array_type(np.ones((200, 10)))) + p = ps["obs" if axis == 0 else "var"] + expected = int(adata.shape[axis] * f_or_n) if isinstance(f_or_n, float) else f_or_n + if p is not None and not replace and expected > (n_possible := (p != 0).sum()): + request.applymarker(pytest.xfail(f"Can’t draw {expected} out of {n_possible}")) # ignoring this warning declaratively is a pain so do it here if find_spec("dask"): @@ -182,12 +208,13 @@ def test_sample( ) rv = sc.pp.sample( adata.X if which == "array" else adata, - fraction, - n=n, + f_or_n if isinstance(f_or_n, float) else None, + n=f_or_n if isinstance(f_or_n, int) else None, replace=replace, axis=axis, # `copy` only effects AnnData inputs copy=dict(copy=True, inplace=False, array=False)[which], + p=p, ) match which: @@ -232,6 +259,12 @@ def test_sample( r"`fraction=-0\.3` needs to be nonnegative", id="frac<0", ), + pytest.param( + dict(n=3, p=np.ones(200, dtype=np.int32)), + ValueError, + r"mask/probabilities array must be boolean or floating point", + id="type(p)", + ), ], ) def test_sample_error(args: dict[str, Any], exc: type[Exception], pattern: str): From 465806b30ed908d9708b8faec866a15e00b923c4 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 20 Dec 2024 14:45:40 +0100 Subject: [PATCH 07/24] (chore): generate 1.11.0 release notes (#3412) --- ci/scripts/towncrier_automation.py | 6 +++- docs/release-notes/1.11.0.md | 38 ++++++++++++++++++++++++++ docs/release-notes/2921.feature.md | 1 - docs/release-notes/3155.feature.md | 1 - docs/release-notes/3180.feature.md | 1 - docs/release-notes/3184.feature.md | 1 - docs/release-notes/3263.feature.md | 1 - docs/release-notes/3267.feature.md | 1 - docs/release-notes/3284.performance.md | 1 - docs/release-notes/3296.feature.md | 1 - docs/release-notes/3307.feature.md | 1 - docs/release-notes/3324.feature.md | 1 - docs/release-notes/3335.feature.md | 1 - docs/release-notes/3362.doc.md | 1 - docs/release-notes/3380.bugfix.md | 1 - docs/release-notes/3384.feature.md | 1 - docs/release-notes/3393.bugfix.md | 1 - docs/release-notes/3407.misc.md | 7 ----- docs/release-notes/3410.feature.md | 1 - docs/release-notes/943.feature.md | 1 - 20 files changed, 43 insertions(+), 25 deletions(-) create mode 100644 docs/release-notes/1.11.0.md delete mode 100644 docs/release-notes/2921.feature.md delete mode 100644 docs/release-notes/3155.feature.md delete mode 100644 docs/release-notes/3180.feature.md delete mode 100644 docs/release-notes/3184.feature.md delete mode 100644 docs/release-notes/3263.feature.md delete mode 100644 docs/release-notes/3267.feature.md delete mode 100644 docs/release-notes/3284.performance.md delete mode 100644 docs/release-notes/3296.feature.md delete mode 100644 docs/release-notes/3307.feature.md delete mode 100644 docs/release-notes/3324.feature.md delete mode 100644 docs/release-notes/3335.feature.md delete mode 100644 docs/release-notes/3362.doc.md delete mode 100644 docs/release-notes/3380.bugfix.md delete mode 100644 docs/release-notes/3384.feature.md delete mode 100644 docs/release-notes/3393.bugfix.md delete mode 100644 docs/release-notes/3407.misc.md delete mode 100644 docs/release-notes/3410.feature.md delete mode 100644 docs/release-notes/943.feature.md diff --git a/ci/scripts/towncrier_automation.py b/ci/scripts/towncrier_automation.py index c532883036..10a8b0c9dc 100755 --- a/ci/scripts/towncrier_automation.py +++ b/ci/scripts/towncrier_automation.py @@ -92,7 +92,11 @@ def main(argv: Sequence[str] | None = None) -> None: f"--base={base_branch}", f"--title={pr_title}", f"--body={pr_description}", - *(["--label=no milestone"] if base_branch == "main" else []), + *( + ["--label=no milestone", "--label=Development Process 🚀"] + if base_branch == "main" + else [] + ), *(["--dry-run"] if args.dry_run else []), ], check=True, diff --git a/docs/release-notes/1.11.0.md b/docs/release-notes/1.11.0.md new file mode 100644 index 0000000000..8c51fe8a37 --- /dev/null +++ b/docs/release-notes/1.11.0.md @@ -0,0 +1,38 @@ +(v1.11.0)= +### 1.11.0 {small}`2024-12-20` + +### Features + +- {func}`~scanpy.pp.sample` supports both upsampling and downsampling of observations and variables. {func}`~scanpy.pp.subsample` is now deprecated. {smaller}`G Eraslan & P Angerer` ({pr}`943`) +- Add `layer` argument to {func}`scanpy.tl.score_genes` and {func}`scanpy.tl.score_genes_cell_cycle` {smaller}`L Zappia` ({pr}`2921`) +- Prevent `raw` conflict with `layer` in {func}`~scanpy.tl.score_genes` {smaller}`S Dicks` ({pr}`3155`) +- Add support for `median` as an aggregation function to the `Aggregation` class in `scanpy.get._aggregated.py`. 
This allows for median-based aggregation of data (e.g., pseudobulk), complementing existing methods like mean- and sum-based aggregation {smaller}`M Dehkordi (Farhad)` ({pr}`3180`) +- Add `key_added` argument to {func}`~scanpy.pp.pca`, {func}`~scanpy.tl.tsne` and {func}`~scanpy.tl.umap` {smaller}`P Angerer` ({pr}`3184`) +- Support running {func}`scanpy.pp.pca` on sparse Dask arrays with the `'covariance_eigh'` solver {smaller}`P Angerer` ({pr}`3263`) +- Use upstreamed {class}`~sklearn.decomposition.PCA` implementation for {class}`~scipy.sparse.csr_array` and {class}`~scipy.sparse.csr_matrix` (see {ref}`sklearn:changes_1_4`) {smaller}`P Angerer` ({pr}`3267`) +- Add explicit support to {func}`scanpy.pp.pca` for `svd_solver='covariance_eigh'` {smaller}`P Angerer` ({pr}`3296`) +- Add support {class}`dask.array.Array` to {func}`scanpy.pp.calculate_qc_metrics` {smaller}`I Gold` ({pr}`3307`) +- Support `layer` parameter in {func}`scanpy.pl.highest_expr_genes` {smaller}`P Angerer` ({pr}`3324`) +- Run numba functions single-threaded when called from inside of a ThreadPool {smaller}`P Angerer` ({pr}`3335`) +- Switch {func}`~scanpy.logging.print_header` and {func}`~scanpy.logging.print_versions` to {mod}`session_info2` {smaller}`P Angerer` ({pr}`3384`) +- Add sampling probabilities/mask parameter `p` to {func}`~scanpy.pp.sample` {smaller}`P Angerer` ({pr}`3410`) + +### Performance + +- Speed up {func}`~scanpy.pp.regress_out` {smaller}`P Ashish, P Angerer & S Dicks` ({pr}`3284`) + +### Documentation + +- Improve {func}`~scanpy.external.pp.harmony_integrate` docs {smaller}`D Kühl` ({pr}`3362`) +- Raise {exc}`FutureWarning` when calling deprecated {mod}`scanpy.pp` functions {smaller}`P Angerer` ({pr}`3380`) +- | Deprecate … | in favor of … | + | --- | --- | + | {func}`scanpy.read_visium` | {func}`squidpy.read.visium` | + | {func}`scanpy.datasets.visium_sge` | {func}`squidpy.datasets.visium` | + | {func}`scanpy.pl.spatial` | {func}`squidpy.pl.spatial_scatter` | + + {smaller}`P Angerer` ({pr}`3407`) + +### Bug fixes + +- Upper-bound {mod}`sklearn` `<1.6.0` due to {issue}`dask/dask-ml#1002` {smaller}`Ilan Gold` ({pr}`3393`) diff --git a/docs/release-notes/2921.feature.md b/docs/release-notes/2921.feature.md deleted file mode 100644 index e3c964abb2..0000000000 --- a/docs/release-notes/2921.feature.md +++ /dev/null @@ -1 +0,0 @@ -Add `layer` argument to {func}`scanpy.tl.score_genes` and {func}`scanpy.tl.score_genes_cell_cycle` {smaller}`L Zappia` diff --git a/docs/release-notes/3155.feature.md b/docs/release-notes/3155.feature.md deleted file mode 100644 index 770c504348..0000000000 --- a/docs/release-notes/3155.feature.md +++ /dev/null @@ -1 +0,0 @@ -Prevent `raw` conflict with `layer` in {func}`~scanpy.tl.score_genes` {smaller}`S Dicks` diff --git a/docs/release-notes/3180.feature.md b/docs/release-notes/3180.feature.md deleted file mode 100644 index ab73dfe18e..0000000000 --- a/docs/release-notes/3180.feature.md +++ /dev/null @@ -1 +0,0 @@ -Add support for `median` as an aggregation function to the `Aggregation` class in `scanpy.get._aggregated.py`. 
This allows for median-based aggregation of data (e.g., pseudobulk), complementing existing methods like mean- and sum-based aggregation {smaller}`M Dehkordi (Farhad)` diff --git a/docs/release-notes/3184.feature.md b/docs/release-notes/3184.feature.md deleted file mode 100644 index 3cc976b141..0000000000 --- a/docs/release-notes/3184.feature.md +++ /dev/null @@ -1 +0,0 @@ -Add `key_added` argument to {func}`~scanpy.pp.pca`, {func}`~scanpy.tl.tsne` and {func}`~scanpy.tl.umap` {smaller}`P Angerer` diff --git a/docs/release-notes/3263.feature.md b/docs/release-notes/3263.feature.md deleted file mode 100644 index 8e924e1799..0000000000 --- a/docs/release-notes/3263.feature.md +++ /dev/null @@ -1 +0,0 @@ -Support running {func}`scanpy.pp.pca` on sparse Dask arrays with the `'covariance_eigh'` solver {smaller}`P Angerer` diff --git a/docs/release-notes/3267.feature.md b/docs/release-notes/3267.feature.md deleted file mode 100644 index 6ea7fb20a2..0000000000 --- a/docs/release-notes/3267.feature.md +++ /dev/null @@ -1 +0,0 @@ -Use upstreamed {class}`~sklearn.decomposition.PCA` implementation for {class}`~scipy.sparse.csr_array` and {class}`~scipy.sparse.csr_matrix` (see {ref}`sklearn:changes_1_4`) {smaller}`P Angerer` diff --git a/docs/release-notes/3284.performance.md b/docs/release-notes/3284.performance.md deleted file mode 100644 index 31c95245ff..0000000000 --- a/docs/release-notes/3284.performance.md +++ /dev/null @@ -1 +0,0 @@ -* Speed up {func}`~scanpy.pp.regress_out` {smaller}`P Ashish, P Angerer & S Dicks` diff --git a/docs/release-notes/3296.feature.md b/docs/release-notes/3296.feature.md deleted file mode 100644 index 74b89945dd..0000000000 --- a/docs/release-notes/3296.feature.md +++ /dev/null @@ -1 +0,0 @@ -Add explicit support to {func}`scanpy.pp.pca` for `svd_solver='covariance_eigh'` {smaller}`P Angerer` diff --git a/docs/release-notes/3307.feature.md b/docs/release-notes/3307.feature.md deleted file mode 100644 index 1505befb40..0000000000 --- a/docs/release-notes/3307.feature.md +++ /dev/null @@ -1 +0,0 @@ -Add support {class}`dask.array.Array` to {func}`scanpy.pp.calculate_qc_metrics` {smaller}`I Gold` diff --git a/docs/release-notes/3324.feature.md b/docs/release-notes/3324.feature.md deleted file mode 100644 index 03d14dceb6..0000000000 --- a/docs/release-notes/3324.feature.md +++ /dev/null @@ -1 +0,0 @@ -Support `layer` parameter in {func}`scanpy.pl.highest_expr_genes` {smaller}`P Angerer` diff --git a/docs/release-notes/3335.feature.md b/docs/release-notes/3335.feature.md deleted file mode 100644 index 77a1723a8e..0000000000 --- a/docs/release-notes/3335.feature.md +++ /dev/null @@ -1 +0,0 @@ -Run numba functions single-threaded when called from inside of a ThreadPool {smaller}`P Angerer` diff --git a/docs/release-notes/3362.doc.md b/docs/release-notes/3362.doc.md deleted file mode 100644 index 1dae77b3e2..0000000000 --- a/docs/release-notes/3362.doc.md +++ /dev/null @@ -1 +0,0 @@ -Improve {func}`~scanpy.external.pp.harmony_integrate` docs {smaller}`D Kühl` diff --git a/docs/release-notes/3380.bugfix.md b/docs/release-notes/3380.bugfix.md deleted file mode 100644 index 633ce346af..0000000000 --- a/docs/release-notes/3380.bugfix.md +++ /dev/null @@ -1 +0,0 @@ -Raise {exc}`FutureWarning` when calling deprecated {mod}`scanpy.pp` functions {smaller}`P Angerer` diff --git a/docs/release-notes/3384.feature.md b/docs/release-notes/3384.feature.md deleted file mode 100644 index 755af9a8a3..0000000000 --- a/docs/release-notes/3384.feature.md +++ /dev/null @@ -1 +0,0 @@ -Switch 
{func}`~scanpy.logging.print_header` and {func}`~scanpy.logging.print_versions` to {mod}`session_info2` {smaller}`P Angerer` diff --git a/docs/release-notes/3393.bugfix.md b/docs/release-notes/3393.bugfix.md deleted file mode 100644 index 22af00f124..0000000000 --- a/docs/release-notes/3393.bugfix.md +++ /dev/null @@ -1 +0,0 @@ -Upper-bound {mod}`sklearn` `<1.6.0` due to {issue}`dask/dask-ml#1002` {smaller}`Ilan Gold` diff --git a/docs/release-notes/3407.misc.md b/docs/release-notes/3407.misc.md deleted file mode 100644 index 4670de6fd4..0000000000 --- a/docs/release-notes/3407.misc.md +++ /dev/null @@ -1,7 +0,0 @@ -| Deprecate … | in favor of … | -| --- | --- | -| {func}`scanpy.read_visium` | {func}`squidpy.read.visium` | -| {func}`scanpy.datasets.visium_sge` | {func}`squidpy.datasets.visium` | -| {func}`scanpy.pl.spatial` | {func}`squidpy.pl.spatial_scatter` | - -{smaller}`P Angerer` diff --git a/docs/release-notes/3410.feature.md b/docs/release-notes/3410.feature.md deleted file mode 100644 index d95ad201ba..0000000000 --- a/docs/release-notes/3410.feature.md +++ /dev/null @@ -1 +0,0 @@ -Add sampling probabilities/mask parameter `p` to {func}`~scanpy.pp.sample` {smaller}`P Angerer` diff --git a/docs/release-notes/943.feature.md b/docs/release-notes/943.feature.md deleted file mode 100644 index 4f5474d762..0000000000 --- a/docs/release-notes/943.feature.md +++ /dev/null @@ -1 +0,0 @@ -{func}`~scanpy.pp.sample` supports both upsampling and downsampling of observations and variables. {func}`~scanpy.pp.subsample` is now deprecated. {smaller}`G Eraslan` & {smaller}`P Angerer` From 5654389f0abee562bcdfe8f4f2ce24a586007dd8 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Fri, 20 Dec 2024 14:51:48 +0100 Subject: [PATCH 08/24] =?UTF-8?q?Note=20that=20it=E2=80=99s=20an=20rc?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/release-notes/1.11.0.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/release-notes/1.11.0.md b/docs/release-notes/1.11.0.md index 8c51fe8a37..c7258ea271 100644 --- a/docs/release-notes/1.11.0.md +++ b/docs/release-notes/1.11.0.md @@ -1,5 +1,5 @@ (v1.11.0)= -### 1.11.0 {small}`2024-12-20` +### 1.11.0rc1 {small}`2024-12-20` ### Features From 09c88b16fa1381dd28174132005b556aba2936c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Jan 2025 09:53:46 +0100 Subject: [PATCH 09/24] [pre-commit.ci] pre-commit autoupdate (#3404) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6c91285096..e891103226 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.2 + rev: v0.8.6 hooks: - id: ruff types_or: [python, pyi, jupyter] From 5a0cc5200edb753f8498519eefdf1fef86a1dd0e Mon Sep 17 00:00:00 2001 From: Dina Kazemi <50038942+dinakazemi@users.noreply.github.com> Date: Thu, 9 Jan 2025 21:10:39 +1100 Subject: [PATCH 10/24] Fix Markdown syntax for preprocessing docs page (#3418) Co-authored-by: Phil Schaf --- docs/api/preprocessing.md | 2 +- docs/api/tools.md | 2 ++ docs/release-notes/3418.doc.md | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 docs/release-notes/3418.doc.md diff --git a/docs/api/preprocessing.md b/docs/api/preprocessing.md index 
36e732a6dc..1834d934a4 100644
--- a/docs/api/preprocessing.md
+++ b/docs/api/preprocessing.md
@@ -49,7 +49,7 @@ For visual quality control, see {func}`~scanpy.pl.highest_expr_genes` and
 
 ### Batch effect correction
 
-Also see [Data integration]. Note that a simple batch correction method is available via {func}`pp.regress_out`. Checkout {mod}`scanpy.external` for more.
+Also see {ref}`data-integration`. Note that a simple batch correction method is available via {func}`pp.regress_out`. Check out {mod}`scanpy.external` for more.
 
 ```{eval-rst}
 .. autosummary::
diff --git a/docs/api/tools.md b/docs/api/tools.md
index 1d51559e5a..13d82b46c7 100644
--- a/docs/api/tools.md
+++ b/docs/api/tools.md
@@ -48,6 +48,8 @@ Compute densities on embeddings.
    tl.paga
 ```
 
+(data-integration)=
+
 ### Data integration
 
 ```{eval-rst}
diff --git a/docs/release-notes/3418.doc.md b/docs/release-notes/3418.doc.md
new file mode 100644
index 0000000000..d46304bf59
--- /dev/null
+++ b/docs/release-notes/3418.doc.md
@@ -0,0 +1 @@
+Fix reference in {mod}`scanpy.pp` page {smaller}`D Kazemi`

From 97681d430f05ca3f0f8737acd8adfb5f3c22b326 Mon Sep 17 00:00:00 2001
From: "Philipp A."
Date: Thu, 9 Jan 2025 15:34:47 +0100
Subject: [PATCH 11/24] Doc fixes for 1.11 (#3415)

---
 docs/release-notes/1.11.0.md | 8 ++++----
 src/scanpy/preprocessing/_simple.py | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/docs/release-notes/1.11.0.md b/docs/release-notes/1.11.0.md
index c7258ea271..a41103e0ec 100644
--- a/docs/release-notes/1.11.0.md
+++ b/docs/release-notes/1.11.0.md
@@ -6,14 +6,14 @@
 - {func}`~scanpy.pp.sample` supports both upsampling and downsampling of observations and variables. {func}`~scanpy.pp.subsample` is now deprecated. {smaller}`G Eraslan & P Angerer` ({pr}`943`)
 - Add `layer` argument to {func}`scanpy.tl.score_genes` and {func}`scanpy.tl.score_genes_cell_cycle` {smaller}`L Zappia` ({pr}`2921`)
 - Prevent `raw` conflict with `layer` in {func}`~scanpy.tl.score_genes` {smaller}`S Dicks` ({pr}`3155`)
-- Add support for `median` as an aggregation function to the `Aggregation` class in `scanpy.get._aggregated.py`. This allows for median-based aggregation of data (e.g., pseudobulk), complementing existing methods like mean- and sum-based aggregation {smaller}`M Dehkordi (Farhad)` ({pr}`3180`)
+- Add support for `median` as an aggregation function to {func}`~scanpy.get.aggregate`.
This allows for median-based aggregation of data (e.g., pseudobulk), complementing existing methods like mean- and sum-based aggregation {smaller}`M Dehkordi (Farhad)` ({pr}`3180`) - Add `key_added` argument to {func}`~scanpy.pp.pca`, {func}`~scanpy.tl.tsne` and {func}`~scanpy.tl.umap` {smaller}`P Angerer` ({pr}`3184`) - Support running {func}`scanpy.pp.pca` on sparse Dask arrays with the `'covariance_eigh'` solver {smaller}`P Angerer` ({pr}`3263`) -- Use upstreamed {class}`~sklearn.decomposition.PCA` implementation for {class}`~scipy.sparse.csr_array` and {class}`~scipy.sparse.csr_matrix` (see {ref}`sklearn:changes_1_4`) {smaller}`P Angerer` ({pr}`3267`) +- Use upstreamed {class}`~sklearn.decomposition.PCA` implementation for {class}`~scipy.sparse.csr_array` and {class}`~scipy.sparse.csr_matrix` (see scikit-learn {ref}`sklearn:changes_1_4`) {smaller}`P Angerer` ({pr}`3267`) - Add explicit support to {func}`scanpy.pp.pca` for `svd_solver='covariance_eigh'` {smaller}`P Angerer` ({pr}`3296`) -- Add support {class}`dask.array.Array` to {func}`scanpy.pp.calculate_qc_metrics` {smaller}`I Gold` ({pr}`3307`) +- Add support for {class}`dask.array.Array` to {func}`scanpy.pp.calculate_qc_metrics` {smaller}`I Gold` ({pr}`3307`) - Support `layer` parameter in {func}`scanpy.pl.highest_expr_genes` {smaller}`P Angerer` ({pr}`3324`) -- Run numba functions single-threaded when called from inside of a ThreadPool {smaller}`P Angerer` ({pr}`3335`) +- Run numba functions single-threaded when called from inside of a {class}`~multiprocessing.pool.ThreadPool` {smaller}`P Angerer` ({pr}`3335`) - Switch {func}`~scanpy.logging.print_header` and {func}`~scanpy.logging.print_versions` to {mod}`session_info2` {smaller}`P Angerer` ({pr}`3384`) - Add sampling probabilities/mask parameter `p` to {func}`~scanpy.pp.sample` {smaller}`P Angerer` ({pr}`3410`) diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index 821615676a..ac68edd376 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -885,7 +885,7 @@ def sample( Rows correspond to cells and columns to genes. fraction Sample to this `fraction` of the number of observations or variables. - (All of them, even if there are `0`s/`False`s in `p`.) + (All of them, even if there are `0`\\ s/`False`\\ s in `p`.) This can be larger than 1.0, if `replace=True`. See `axis` and `replace`. n From e94ae5fb6f218174b0cba9b6fa8cd950caffbf5a Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 9 Jan 2025 15:43:16 +0100 Subject: [PATCH 12/24] Mention towncrier in contribution docs (#3427) --- docs/dev/code.md | 1 + docs/dev/documentation.md | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/dev/code.md b/docs/dev/code.md index 1e9d295725..3ca393c8f7 100644 --- a/docs/dev/code.md +++ b/docs/dev/code.md @@ -9,6 +9,7 @@ 5. {ref}`Make sure all tests are passing ` 6. {ref}`Build and visually check any changed documentation ` 7. {ref}`Open a PR back to the main repository ` +8. {ref}`Add a release note to your PR ` ## Code style diff --git a/docs/dev/documentation.md b/docs/dev/documentation.md index d9c3f6e034..dcad9533ed 100644 --- a/docs/dev/documentation.md +++ b/docs/dev/documentation.md @@ -12,10 +12,12 @@ Sometimes these caches are not invalidated when you've updated the docs. If docs are not updating the way you expect, first try "force reloading" your browser page – e.g. reload the page without using the cache. 
Next, if problems persist, clear the sphinx cache (`hatch run docs:clean`) and try building them again. +(adding-to-the-docs)= + ## Adding to the docs For any user-visible changes, please make sure a note has been added to the release notes using [`hatch run towncrier:create`][towncrier create]. -We recommend waiting on this until your PR is close to done since this can often causes merge conflicts. +When asked for “Issue number (`+` if none)”, enter the *PR number* instead. Once you've added a new function to the documentation, you'll need to make sure there is a link somewhere in the documentation site pointing to it. This should be added to `docs/api.md` under a relevant heading. From 4ac2a3fffcadf8ba9e874650bc7e65438cbd7e27 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 9 Jan 2025 15:46:02 +0100 Subject: [PATCH 13/24] Fix wilcoxon for >10M cells (#3426) --- docs/release-notes/3426.bugfix.md | 1 + src/scanpy/tools/_rank_genes_groups.py | 11 +++++------ tests/test_rank_genes_groups.py | 7 +++++++ 3 files changed, 13 insertions(+), 6 deletions(-) create mode 100644 docs/release-notes/3426.bugfix.md diff --git a/docs/release-notes/3426.bugfix.md b/docs/release-notes/3426.bugfix.md new file mode 100644 index 0000000000..4565f1ee35 --- /dev/null +++ b/docs/release-notes/3426.bugfix.md @@ -0,0 +1 @@ +Fix {func}`~scanpy.tl.rank_genes_groups` compatibility with data >10M cells {smaller}`P Angerer` diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 2c214fcfdd..cafb78c6f1 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -2,7 +2,6 @@ from __future__ import annotations -from math import floor from typing import TYPE_CHECKING, Literal import numpy as np @@ -32,6 +31,8 @@ # Used with get_literal_vals _Method = Literal["logreg", "t-test", "wilcoxon", "t-test_overestim_var"] +_CONST_MAX_SIZE = 10000000 + def _select_top_n(scores: NDArray, n_top: int): n_from = scores.shape[0] @@ -47,9 +48,7 @@ def _ranks( X: np.ndarray | sparse.csr_matrix | sparse.csc_matrix, mask_obs: NDArray[np.bool_] | None = None, mask_obs_rest: NDArray[np.bool_] | None = None, -): - CONST_MAX_SIZE = 10000000 - +) -> Generator[tuple[pd.DataFrame, int, int], None, None]: n_genes = X.shape[1] if issparse(X): @@ -71,7 +70,7 @@ def _ranks( get_chunk = lambda X, left, right: adapt(X[:, left:right]) # Calculate chunk frames - max_chunk = floor(CONST_MAX_SIZE / n_cells) + max_chunk = max(_CONST_MAX_SIZE // n_cells, 1) for left in range(0, n_genes, max_chunk): right = min(left + max_chunk, n_genes) @@ -81,7 +80,7 @@ def _ranks( yield ranks, left, right -def _tiecorrect(ranks): +def _tiecorrect(ranks: pd.DataFrame) -> np.float64: size = np.float64(ranks.shape[0]) if size < 2: return np.repeat(ranks.shape[1], 1.0) diff --git a/tests/test_rank_genes_groups.py b/tests/test_rank_genes_groups.py index a36e6b14f1..788c7e705d 100644 --- a/tests/test_rank_genes_groups.py +++ b/tests/test_rank_genes_groups.py @@ -307,6 +307,13 @@ def test_wilcoxon_tie_correction(reference): np.testing.assert_allclose(test_obj.stats[groups[0]]["pvals"], pvals) +def test_wilcoxon_huge_data(monkeypatch): + max_size = 300 + adata = pbmc68k_reduced() + monkeypatch.setattr(sc.tl._rank_genes_groups, "_CONST_MAX_SIZE", max_size) + rank_genes_groups(adata, groupby="bulk_labels", method="wilcoxon") + + @pytest.mark.parametrize( ("n_genes_add", "n_genes_out_add"), [pytest.param(0, 0, id="equal"), pytest.param(2, 1, id="more")], From 
e71dc55e988f9dc46531b1ecb6754bacdef06cf3 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 9 Jan 2025 16:01:16 +0100 Subject: [PATCH 14/24] Update author/maintainer metadata (#3413) --- docs/contributors.md | 19 +++++++++++-------- pyproject.toml | 10 ++++++---- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/docs/contributors.md b/docs/contributors.md index a9b2d79e3c..2e9c62d255 100644 --- a/docs/contributors.md +++ b/docs/contributors.md @@ -1,22 +1,25 @@ # Contributors [anndata graph](https://github.com/scverse/anndata/graphs/contributors>) | [scanpy graph](https://github.com/scverse/scanpy/graphs/contributors)| ☀ = maintainer + ## Current developers -- [Isaac Virshup](https://github.com/ivirshup), lead developer since 2019 ☀ -- [Gökcen Eraslan](https://twitter.com/gokcen), developer, diverse contributions ☀ -- [Sergei Rybakov](https://github.com/Koncopd), developer, diverse contributions ☀ -- [Fidel Ramirez](https://github.com/fidelram) developer, plotting ☀ -- [Giovanni Palla](https://twitter.com/g_palla1), developer, spatial data -- [Malte Luecken](https://twitter.com/MDLuecken), developer, community & forum +- [Philipp Angerer](https://github.com/flying-sheep), lead developer since 2023, software quality, initial anndata conception ☀ +- [Ilan Gold](https://github.com/ilan-gold), developer, Dask ☀ +- [Severin Dicks](https://github.com/SeverinDicks), developer, performance ☀ - [Lukas Heumos](https://twitter.com/LukasHeumos), developer, diverse contributions -- [Philipp Angerer](https://github.com/flying-sheep), developer, software quality, initial anndata conception ☀ ## Other roles +- [Isaac Virshup](https://github.com/ivirshup), lead developer 2019-2023 - [Alex Wolf](https://twitter.com/falexwolf): lead developer 2016-2019, initial anndata & scanpy conception - [Fabian Theis](https://twitter.com/fabian_theis) & lab: enabling guidance, support and environment ## Former developers -- Tom White: developer 2018-2019, distributed computing +- [Tom White](https://github.com/tomwhite): developer 2018-2019, distributed computing +- [Gökcen Eraslan](https://twitter.com/gokcen), developer, diverse contributions +- [Sergei Rybakov](https://github.com/Koncopd), developer, diverse contributions +- [Fidel Ramirez](https://github.com/fidelram) developer, plotting +- [Giovanni Palla](https://twitter.com/g_palla1), developer, spatial data +- [Malte Luecken](https://twitter.com/MDLuecken), developer, community & forum diff --git a/pyproject.toml b/pyproject.toml index 8e23afb14b..00f02ad27c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,9 +22,9 @@ authors = [ {name = "Andrés R. 
Muñoz-Rojas"}, ] maintainers = [ - {name = "Isaac Virshup", email = "ivirshup@gmail.com"}, {name = "Philipp Angerer", email = "phil.angerer@gmail.com"}, - {name = "Alex Wolf", email = "f.alex.wolf@gmx.de"}, + {name = "Ilan Gold"}, + {name = "Severin Dicks"}, ] readme = "README.md" classifiers = [ @@ -70,12 +70,14 @@ dependencies = [ ] dynamic = ["version"] +# https://docs.pypi.org/project_metadata/#project-urls [project.urls] Documentation = "https://scanpy.readthedocs.io/" Source = "https://github.com/scverse/scanpy" -Home-page = "https://scanpy.org" +Homepage = "https://scanpy.org" Discourse = "https://discourse.scverse.org/c/help/scanpy/37" -Twitter = "https://twitter.com/scverse_team" +Bluesky = "https://bsky.app/profile/scverse.bsky.social" +Twitter = "https://x.com/scverse_team" [project.scripts] scanpy = "scanpy.cli:console_main" From c26290b43cf7f2630ba512a04902f2f6988d4825 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 10 Jan 2025 08:52:06 +0100 Subject: [PATCH 15/24] Formatting (#3414) Co-authored-by: Ilan Gold --- .gitignore | 1 - .pre-commit-config.yaml | 15 ++-- .taplo.toml | 5 ++ .vscode/launch.json | 26 ++++++ .vscode/settings.json | 22 +++++ biome.jsonc | 21 +++++ hatch.toml | 24 ++--- pyproject.toml | 90 +++++++++---------- .../1.0.0/spatial/scalefactors_json.json | 7 +- 9 files changed, 146 insertions(+), 65 deletions(-) create mode 100644 .taplo.toml create mode 100644 .vscode/launch.json create mode 100644 .vscode/settings.json create mode 100644 biome.jsonc diff --git a/.gitignore b/.gitignore index d21120ee95..65f9de7e0a 100644 --- a/.gitignore +++ b/.gitignore @@ -42,7 +42,6 @@ Thumbs.db # IDEs and editors /.idea/ -/.vscode/ # asv benchmark files /benchmarks/.asv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e891103226..e87eb88663 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,14 +3,11 @@ repos: rev: v0.8.6 hooks: - id: ruff - types_or: [python, pyi, jupyter] args: ["--fix"] - id: ruff-format - types_or: [python, pyi, jupyter] # The following can be removed once PLR0917 is out of preview - name: ruff preview rules id: ruff - types_or: [python, pyi, jupyter] args: ["--preview", "--select=PLR0917"] - repo: https://github.com/flying-sheep/bibfmt rev: v4.3.0 @@ -19,6 +16,15 @@ repos: args: - --sort-by-bibkey - --drop=abstract +- repo: https://github.com/biomejs/pre-commit + rev: v0.6.1 + hooks: + - id: biome-format + additional_dependencies: ["@biomejs/biome@1.9.4"] +- repo: https://github.com/ComPWA/taplo-pre-commit + rev: v0.9.3 + hooks: + - id: taplo-format - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: @@ -34,6 +40,3 @@ repos: - id: detect-private-key - id: no-commit-to-branch args: ["--branch=main"] - -ci: - autofix_prs: false diff --git a/.taplo.toml b/.taplo.toml new file mode 100644 index 0000000000..41a6cdc5cc --- /dev/null +++ b/.taplo.toml @@ -0,0 +1,5 @@ +[formatting] +array_auto_collapse = false +column_width = 120 +compact_arrays = false +indent_string = ' ' diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000..d87ef7c54f --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,26 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Build Documentation", + "type": "debugpy", + "request": "launch", + "module": "sphinx", + "args": ["-M", "html", ".", "_build"], + "cwd": "${workspaceFolder}/docs", + "console": "internalConsole", + "justMyCode": false, + }, + { + "name": "Python: Debug Test", + "type": "debugpy", + "request": 
"launch", + "program": "${file}", + "purpose": ["debug-test"], + "console": "internalConsole", + "justMyCode": false, + "env": { "PYTEST_ADDOPTS": "--color=yes" }, + "presentation": { "hidden": true }, + }, + ], +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000000..ae719a4ec8 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,22 @@ +{ + "[python][toml][json][jsonc]": { + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.organizeImports": "explicit", + "source.fixAll": "explicit", + }, + }, + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff", + }, + "[toml]": { + "editor.defaultFormatter": "tamasfe.even-better-toml", + }, + "[json][jsonc]": { + "editor.defaultFormatter": "biomejs.biome", + }, + "python.analysis.typeCheckingMode": "basic", + "python.testing.pytestArgs": ["-vv", "--color=yes"], + "python.testing.pytestEnabled": true, + "python.terminal.activateEnvironment": true, +} diff --git a/biome.jsonc b/biome.jsonc new file mode 100644 index 0000000000..cf4a677503 --- /dev/null +++ b/biome.jsonc @@ -0,0 +1,21 @@ +{ + "$schema": "https://biomejs.dev/schemas/1.9.4/schema.json", + "formatter": { + "indentStyle": "space", + "indentWidth": 4, + }, + "overrides": [ + { + "include": ["./.vscode/*.json", "**/*.jsonc", "**/asv.conf.json"], + "json": { + "formatter": { + "trailingCommas": "all", + }, + "parser": { + "allowComments": true, + "allowTrailingCommas": true, + }, + }, + }, + ], +} diff --git a/hatch.toml b/hatch.toml index 3163d5d82d..b0a1084c61 100644 --- a/hatch.toml +++ b/hatch.toml @@ -1,9 +1,9 @@ [envs.default] installer = "uv" -features = ["dev"] +features = [ "dev" ] [envs.docs] -features = ["doc"] +features = [ "doc" ] scripts.build = "sphinx-build -M html docs docs/_build -W --keep-going {args}" scripts.open = "python3 -m webbrowser -t docs/_build/html/index.html" scripts.clean = "git clean -fdX -- {args:docs}" @@ -14,23 +14,23 @@ scripts.build = "python3 ci/scripts/towncrier_automation.py {args}" scripts.clean = "git restore --source=HEAD --staged --worktree -- docs/release-notes" [envs.hatch-test] -default-args = [] -features = ["test", "dask-ml"] -extra-dependencies = ["ipykernel"] +default-args = [ ] +features = [ "test", "dask-ml" ] +extra-dependencies = [ "ipykernel" ] overrides.matrix.deps.env-vars = [ - { if = ["pre"], key = "UV_PRERELEASE", value = "allow" }, - { if = ["min"], key = "UV_CONSTRAINT", value = "ci/scanpy-min-deps.txt" }, + { if = [ "pre" ], key = "UV_PRERELEASE", value = "allow" }, + { if = [ "min" ], key = "UV_CONSTRAINT", value = "ci/scanpy-min-deps.txt" }, ] overrides.matrix.deps.pre-install-commands = [ - { if = ["min"], value = "uv run ci/scripts/min-deps.py pyproject.toml --all-extras -o ci/scanpy-min-deps.txt" }, + { if = [ "min" ], value = "uv run ci/scripts/min-deps.py pyproject.toml --all-extras -o ci/scanpy-min-deps.txt" }, ] overrides.matrix.deps.python = [ - { if = ["min"], value = "3.10" }, - { if = ["stable", "full", "pre"], value = "3.12" }, + { if = [ "min" ], value = "3.10" }, + { if = [ "stable", "full", "pre" ], value = "3.12" }, ] overrides.matrix.deps.features = [ - { if = ["full"], value = "test-full" }, + { if = [ "full" ], value = "test-full" }, ] [[envs.hatch-test.matrix]] -deps = ["stable", "full", "pre", "min"] +deps = [ "stable", "full", "pre", "min" ] diff --git a/pyproject.toml b/pyproject.toml index 00f02ad27c..94654363e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] build-backend = "hatchling.build" 
-requires = ["hatchling", "hatch-vcs"] +requires = [ "hatchling", "hatch-vcs" ] [project] name = "scanpy" @@ -8,23 +8,23 @@ description = "Single-Cell Analysis in Python." requires-python = ">=3.10" license = "BSD-3-clause" authors = [ - {name = "Alex Wolf"}, - {name = "Philipp Angerer"}, - {name = "Fidel Ramirez"}, - {name = "Isaac Virshup"}, - {name = "Sergei Rybakov"}, - {name = "Gokcen Eraslan"}, - {name = "Tom White"}, - {name = "Malte Luecken"}, - {name = "Davide Cittaro"}, - {name = "Tobias Callies"}, - {name = "Marius Lange"}, - {name = "Andrés R. Muñoz-Rojas"}, + { name = "Alex Wolf" }, + { name = "Philipp Angerer" }, + { name = "Fidel Ramirez" }, + { name = "Isaac Virshup" }, + { name = "Sergei Rybakov" }, + { name = "Gokcen Eraslan" }, + { name = "Tom White" }, + { name = "Malte Luecken" }, + { name = "Davide Cittaro" }, + { name = "Tobias Callies" }, + { name = "Marius Lange" }, + { name = "Andrés R. Muñoz-Rojas" }, ] maintainers = [ - {name = "Philipp Angerer", email = "phil.angerer@gmail.com"}, - {name = "Ilan Gold"}, - {name = "Severin Dicks"}, + { name = "Philipp Angerer", email = "phil.angerer@gmail.com" }, + { name = "Ilan Gold", email = "ilan.gold@helmholtz-munich.de" }, + { name = "Severin Dicks" }, ] readme = "README.md" classifiers = [ @@ -56,7 +56,7 @@ dependencies = [ "tqdm", "scikit-learn>=1.1,<1.6.0", "statsmodels>=0.13", - "patsy!=1.0.0", # https://github.com/pydata/patsy/issues/215 + "patsy!=1.0.0", # https://github.com/pydata/patsy/issues/215 "networkx>=2.7", "natsort", "joblib", @@ -65,10 +65,10 @@ dependencies = [ "pynndescent>=0.5", "packaging>=21.3", "session-info2", - "legacy-api-wrap>=1.4", # for positional API deprecations + "legacy-api-wrap>=1.4", # for positional API deprecations "typing-extensions; python_version < '3.13'", ] -dynamic = ["version"] +dynamic = [ "version" ] # https://docs.pypi.org/project_metadata/#project-urls [project.urls] @@ -120,13 +120,13 @@ doc = [ "sphinx-design", "sphinx-tabs", "readthedocs-sphinx-search", - "sphinxext-opengraph", # for nice cards when sharing on social + "sphinxext-opengraph", # for nice cards when sharing on social "sphinx-copybutton", "nbsphinx>=0.9", - "ipython>=7.20", # for nbsphinx code highlighting + "ipython>=7.20", # for nbsphinx code highlighting "matplotlib!=3.6.1", "sphinxcontrib-bibtex", - "setuptools", # undeclared dependency of sphinxcontrib-bibtex→pybtex + "setuptools", # undeclared dependency of sphinxcontrib-bibtex→pybtex # TODO: remove necessity for being able to import doc-linked classes "scanpy[paga,dask-ml]", "sam-algorithm", @@ -139,22 +139,22 @@ dev = [ "towncrier", ] # Algorithms -paga = ["igraph"] -louvain = ["igraph", "louvain>=0.6.0,!=0.6.2"] # Louvain community detection -leiden = ["igraph>=0.10", "leidenalg>=0.9.0"] # Leiden community detection -bbknn = ["bbknn"] # Batch balanced KNN (batch correction) -magic = ["magic-impute>=2.0"] # MAGIC imputation method -skmisc = ["scikit-misc>=0.1.3"] # highly_variable_genes method 'seurat_v3' -harmony = ["harmonypy"] # Harmony dataset integration -scanorama = ["scanorama"] # Scanorama dataset integration -scrublet = ["scikit-image"] # Doublet detection with automatic thresholds +paga = [ "igraph" ] +louvain = [ "igraph", "louvain>=0.6.0,!=0.6.2" ] # Louvain community detection +leiden = [ "igraph>=0.10", "leidenalg>=0.9.0" ] # Leiden community detection +bbknn = [ "bbknn" ] # Batch balanced KNN (batch correction) +magic = [ "magic-impute>=2.0" ] # MAGIC imputation method +skmisc = [ "scikit-misc>=0.1.3" ] # highly_variable_genes method 
'seurat_v3' +harmony = [ "harmonypy" ] # Harmony dataset integration +scanorama = [ "scanorama" ] # Scanorama dataset integration +scrublet = [ "scikit-image" ] # Doublet detection with automatic thresholds # Acceleration -rapids = ["cudf>=0.9", "cuml>=0.9", "cugraph>=0.9"] # GPU accelerated calculation of neighbors -dask = ["dask[array]>=2022.09.2,<2024.8.0"] # Use the Dask parallelization engine -dask-ml = ["dask-ml", "scanpy[dask]"] # Dask-ML for sklearn-like API +rapids = [ "cudf>=0.9", "cuml>=0.9", "cugraph>=0.9" ] # GPU accelerated calculation of neighbors +dask = [ "dask[array]>=2022.09.2,<2024.8.0" ] # Use the Dask parallelization engine +dask-ml = [ "dask-ml", "scanpy[dask]" ] # Dask-ML for sklearn-like API [tool.hatch.build.targets.wheel] -packages = ["src/testing", "src/scanpy"] +packages = [ "src/testing", "src/scanpy" ] [tool.hatch.version] source = "vcs" raw-options.version_scheme = "release-branch-semver" @@ -169,8 +169,8 @@ addopts = [ "-ptesting.scanpy._pytest", "--pyargs", ] -testpaths = ["./tests", "./ci", "scanpy"] -norecursedirs = ["tests/_images"] +testpaths = [ "./tests", "./ci", "scanpy" ] +norecursedirs = [ "tests/_images" ] xfail_strict = true nunit_attach_on = "fail" markers = [ @@ -203,12 +203,12 @@ filterwarnings = [ [tool.coverage.run] data_file = "test-data/coverage" -source_pkgs = ["scanpy"] -omit = ["tests/*", "src/testing/*"] +source_pkgs = [ "scanpy" ] +omit = [ "tests/*", "src/testing/*" ] [tool.coverage.xml] output = "test-data/coverage.xml" [tool.coverage.paths] -source = [".", "**/site-packages"] +source = [ ".", "**/site-packages" ] [tool.coverage.report] exclude_also = [ "if __name__ == .__main__.:", @@ -218,7 +218,7 @@ exclude_also = [ ] [tool.ruff] -src = ["src"] +src = [ "src" ] [tool.ruff.format] docstring-code-format = true @@ -254,10 +254,10 @@ ignore = [ ] [tool.ruff.lint.per-file-ignores] # Do not assign a lambda expression, use a def -"src/scanpy/tools/_rank_genes_groups.py" = ["E731"] +"src/scanpy/tools/_rank_genes_groups.py" = [ "E731" ] [tool.ruff.lint.isort] -known-first-party = ["scanpy", "testing.scanpy"] -required-imports = ["from __future__ import annotations"] +known-first-party = [ "scanpy", "testing.scanpy" ] +required-imports = [ "from __future__ import annotations" ] [tool.ruff.lint.flake8-tidy-imports.banned-api] "pytest.importorskip".msg = "Use the “@needs” decorator/mark instead" "pandas.api.types.is_categorical_dtype".msg = "Use isinstance(s.dtype, CategoricalDtype) instead" @@ -267,7 +267,7 @@ required-imports = ["from __future__ import annotations"] "numba.jit".msg = "Use `scanpy._compat.njit` instead" "numba.njit".msg = "Use `scanpy._compat.njit` instead" [tool.ruff.lint.flake8-type-checking] -exempt-modules = [] +exempt-modules = [ ] strict = true [tool.towncrier] diff --git a/tests/_data/visium_data/1.0.0/spatial/scalefactors_json.json b/tests/_data/visium_data/1.0.0/spatial/scalefactors_json.json index 9f47f51518..5479b589c0 100644 --- a/tests/_data/visium_data/1.0.0/spatial/scalefactors_json.json +++ b/tests/_data/visium_data/1.0.0/spatial/scalefactors_json.json @@ -1 +1,6 @@ -{"spot_diameter_fullres": 89.42751063343188, "tissue_hires_scalef": 0.150015, "fiducial_diameter_fullres": 144.45982486939, "tissue_lowres_scalef": 0.045004502} \ No newline at end of file +{ + "spot_diameter_fullres": 89.42751063343188, + "tissue_hires_scalef": 0.150015, + "fiducial_diameter_fullres": 144.45982486939, + "tissue_lowres_scalef": 0.045004502 +} From 72b0b815f1e14bc3590f8648368ba95133c3727b Mon Sep 17 00:00:00 2001 From: "Philipp 
A." Date: Fri, 10 Jan 2025 15:25:02 +0100 Subject: [PATCH 16/24] (chore) upper bound zarr version in tests to <3 (#3432) --- benchmarks/asv.conf.json | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index 98192b3725..404e0b83dd 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -78,7 +78,7 @@ "natsort": [""], "pandas": [""], "memory_profiler": [""], - "zarr": [""], + "zarr": ["2.18.4"], "pytest": [""], "scanpy": [""], "python-igraph": [""], diff --git a/pyproject.toml b/pyproject.toml index 94654363e3..cd26d2f9a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,7 @@ test = [ "scanpy[test-min]", # Optional but important dependencies "scanpy[leiden]", - "zarr", + "zarr<3", "scanpy[dask]", "scanpy[scrublet]", ] From 66f1b61dc6f1a0ae653dde3fa9ba835402f88c1a Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 13 Jan 2025 10:36:35 +0100 Subject: [PATCH 17/24] Fix flaky doublet test (#3436) --- tests/external/test_hashsolo.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/external/test_hashsolo.py b/tests/external/test_hashsolo.py index a6f0a79971..d338d1b6ec 100644 --- a/tests/external/test_hashsolo.py +++ b/tests/external/test_hashsolo.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +import pandas as pd from anndata import AnnData import scanpy.external as sce @@ -27,9 +28,11 @@ def test_cell_demultiplexing(): sce.pp.hashsolo(test_data, test_data.obs.columns) doublets = ["Doublet"] * 10 - classes = list( - np.repeat(np.arange(10), 98).reshape(98, 10, order="F").ravel().astype(str) - ) + classes = np.repeat(np.arange(10), 98).reshape(98, 10, order="F").ravel().tolist() negatives = ["Negative"] * 10 - classification = doublets + classes + negatives - assert test_data.obs["Classification"].astype(str).tolist() == classification + expected = pd.array(doublets + classes + negatives, dtype="string") + classification = test_data.obs["Classification"].array.astype("string") + # This is a bit flaky, so allow some mismatches: + if (expected != classification).sum() > 3: + # Compare lists for better error message + assert classification.tolist() == expected.tolist() From f7acd0234054642c755d306d50f50b6d791105f4 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Mon, 13 Jan 2025 13:00:01 +0100 Subject: [PATCH 18/24] (chore): Update to Ruff 0.9 and add EM lints (#3437) --- .pre-commit-config.yaml | 2 +- ci/scripts/min-deps.py | 4 +- docs/extensions/param_police.py | 3 +- pyproject.toml | 3 +- src/scanpy/__init__.py | 5 +- src/scanpy/_settings.py | 12 ++- src/scanpy/_utils/__init__.py | 77 ++++++++------- src/scanpy/_utils/_doctests.py | 3 +- src/scanpy/_utils/compute/is_constant.py | 6 +- .../experimental/pp/_highly_variable_genes.py | 17 ++-- src/scanpy/experimental/pp/_normalization.py | 9 +- src/scanpy/external/exporting.py | 8 +- src/scanpy/external/pl.py | 5 +- src/scanpy/external/pp/_bbknn.py | 3 +- src/scanpy/external/pp/_dca.py | 3 +- src/scanpy/external/pp/_harmony_integrate.py | 3 +- src/scanpy/external/pp/_hashsolo.py | 8 +- src/scanpy/external/pp/_magic.py | 12 ++- src/scanpy/external/pp/_mnn_correct.py | 6 +- .../external/pp/_scanorama_integrate.py | 6 +- src/scanpy/external/tl/_harmony_timeseries.py | 6 +- src/scanpy/external/tl/_palantir.py | 3 +- src/scanpy/external/tl/_phate.py | 5 +- src/scanpy/external/tl/_phenograph.py | 6 +- src/scanpy/external/tl/_pypairs.py | 6 +- src/scanpy/external/tl/_sam.py | 3 +- src/scanpy/external/tl/_trimap.py | 6 +- src/scanpy/external/tl/_wishbone.py | 11 ++- src/scanpy/get/_aggregated.py | 18 ++-- src/scanpy/get/get.py | 33 ++++--- src/scanpy/logging.py | 2 +- src/scanpy/metrics/_gearys_c.py | 3 +- src/scanpy/metrics/_morans_i.py | 3 +- src/scanpy/neighbors/__init__.py | 20 ++-- src/scanpy/plotting/_anndata.py | 97 +++++++++++-------- src/scanpy/plotting/_baseplot_class.py | 16 ++- src/scanpy/plotting/_dotplot.py | 10 +- src/scanpy/plotting/_scrublet.py | 5 +- src/scanpy/plotting/_stacked_violin.py | 2 +- src/scanpy/plotting/_tools/__init__.py | 53 +++++----- src/scanpy/plotting/_tools/paga.py | 41 +++++--- src/scanpy/plotting/_tools/scatterplots.py | 50 ++++++---- src/scanpy/plotting/_utils.py | 37 ++++--- src/scanpy/preprocessing/_combat.py | 14 +-- .../preprocessing/_deprecated/__init__.py | 3 +- .../_deprecated/highly_variable_genes.py | 6 +- .../preprocessing/_highly_variable_genes.py | 19 ++-- src/scanpy/preprocessing/_normalization.py | 14 +-- src/scanpy/preprocessing/_pca/__init__.py | 15 ++- src/scanpy/preprocessing/_qc.py | 8 +- src/scanpy/preprocessing/_recipes.py | 3 +- src/scanpy/preprocessing/_scale.py | 12 +-- .../preprocessing/_scrublet/pipeline.py | 6 +- src/scanpy/preprocessing/_simple.py | 37 ++++--- src/scanpy/preprocessing/_utils.py | 3 +- src/scanpy/queries/_queries.py | 16 +-- src/scanpy/readwrite.py | 77 +++++++++------ src/scanpy/tools/_dendrogram.py | 8 +- src/scanpy/tools/_diffmap.py | 8 +- src/scanpy/tools/_dpt.py | 20 ++-- src/scanpy/tools/_draw_graph.py | 3 +- src/scanpy/tools/_embedding_density.py | 12 ++- src/scanpy/tools/_ingest.py | 22 +++-- src/scanpy/tools/_leiden.py | 20 ++-- src/scanpy/tools/_louvain.py | 8 +- src/scanpy/tools/_marker_gene_overlap.py | 21 ++-- src/scanpy/tools/_paga.py | 22 +++-- src/scanpy/tools/_rank_genes_groups.py | 26 +++-- src/scanpy/tools/_score_genes.py | 14 +-- src/scanpy/tools/_sim.py | 36 +++---- src/scanpy/tools/_umap.py | 8 +- src/scanpy/tools/_utils.py | 16 +-- src/scanpy/tools/_utils_clustering.py | 8 +- src/testing/scanpy/_pytest/__init__.py | 9 +- tests/conftest.py | 3 +- tests/external/test_wishbone.py | 6 +- tests/test_dendrogram.py | 2 +- tests/test_get.py | 2 +- tests/test_highly_variable_genes.py | 3 +- tests/test_normalization.py | 12 +-- tests/test_rank_genes_groups.py | 14 ++- 81 files changed, 662 insertions(+), 
505 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e87eb88663..c5e0e91d8c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.6 + rev: v0.9.1 hooks: - id: ruff args: ["--fix"] diff --git a/ci/scripts/min-deps.py b/ci/scripts/min-deps.py index 0d49d151ef..4efc304cb6 100755 --- a/ci/scripts/min-deps.py +++ b/ci/scripts/min-deps.py @@ -71,7 +71,9 @@ def extract_min_deps( # If we are referring to other optional dependency lists, resolve them if req.name == project_name: - assert req.extras, f"Project included itself as dependency, without specifying extras: {req}" + assert req.extras, ( + f"Project included itself as dependency, without specifying extras: {req}" + ) for extra in req.extras: extra_deps = pyproject["project"]["optional-dependencies"][extra] dependencies += map(Requirement, extra_deps) diff --git a/docs/extensions/param_police.py b/docs/extensions/param_police.py index 37942d3687..234ad28e62 100644 --- a/docs/extensions/param_police.py +++ b/docs/extensions/param_police.py @@ -37,7 +37,8 @@ def show_param_warnings(app, exception): line, ) if param_warnings: - raise RuntimeError("Encountered text parameter type. Use annotations.") + msg = "Encountered text parameter type. Use annotations." + raise RuntimeError(msg) def setup(app: Sphinx): diff --git a/pyproject.toml b/pyproject.toml index cd26d2f9a2..71f7f1c482 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -230,7 +230,7 @@ select = [ "W", # Warning detected by Pycodestyle "UP", # pyupgrade "I", # isort - "TCH", # manage type checking blocks + "TC", # manage type checking blocks "TID251", # Banned imports "ICN", # Follow import conventions "PTH", # Pathlib instead of os.path @@ -239,6 +239,7 @@ select = [ "FBT", # No positional boolean parameters "PT", # Pytest style "SIM", # Simplify control flow + "EM", # Traceback-friendly error messages ] ignore = [ # line too long -> we accept long comment lines; black gets rid of long code lines diff --git a/src/scanpy/__init__.py b/src/scanpy/__init__.py index bbcc86437b..b844372d1e 100644 --- a/src/scanpy/__init__.py +++ b/src/scanpy/__init__.py @@ -15,9 +15,8 @@ try: from ._version import __version__ except ModuleNotFoundError: - raise RuntimeError( - "scanpy is not correctly installed. Please install it, e.g. with pip." - ) + msg = "scanpy is not correctly installed. Please install it, e.g. with pip." + raise RuntimeError(msg) from ._utils import check_versions diff --git a/src/scanpy/_settings.py b/src/scanpy/_settings.py index 5543689ef7..fc9ead09b0 100644 --- a/src/scanpy/_settings.py +++ b/src/scanpy/_settings.py @@ -83,7 +83,8 @@ def _type_check(var: Any, varname: str, types: type | tuple[type, ...]): else: type_names = [t.__name__ for t in types] possible_types_str = f"{', '.join(type_names[:-1])} or {type_names[-1]}" - raise TypeError(f"{varname} must be of type {possible_types_str}") + msg = f"{varname} must be of type {possible_types_str}" + raise TypeError(msg) class ScanpyConfig: @@ -180,10 +181,11 @@ def verbosity(self, verbosity: Verbosity | int | str): elif isinstance(verbosity, str): verbosity = verbosity.lower() if verbosity not in verbosity_str_options: - raise ValueError( + msg = ( f"Cannot set verbosity to {verbosity}. 
" f"Accepted string values are: {verbosity_str_options}" ) + raise ValueError(msg) else: self._verbosity = Verbosity(verbosity_str_options.index(verbosity)) else: @@ -214,10 +216,11 @@ def file_format_data(self, file_format: str): _type_check(file_format, "file_format_data", str) file_format_options = {"txt", "csv", "h5ad"} if file_format not in file_format_options: - raise ValueError( + msg = ( f"Cannot set file_format_data to {file_format}. " f"Must be one of {file_format_options}" ) + raise ValueError(msg) self._file_format_data = file_format @property @@ -322,10 +325,11 @@ def cache_compression(self) -> str | None: @cache_compression.setter def cache_compression(self, cache_compression: str | None): if cache_compression not in {"lzf", "gzip", None}: - raise ValueError( + msg = ( f"`cache_compression` ({cache_compression}) " "must be in {'lzf', 'gzip', None}" ) + raise ValueError(msg) self._cache_compression = cache_compression @property diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index 67e2ae03c8..326ea216d1 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -93,11 +93,12 @@ def __getattr__(self, attr: str): def ensure_igraph() -> None: if importlib.util.find_spec("igraph"): return - raise ImportError( + msg = ( "Please install the igraph package: " "`conda install -c conda-forge python-igraph` or " "`pip3 install igraph`." ) + raise ImportError(msg) @contextmanager @@ -120,10 +121,11 @@ def check_versions(): if Version(anndata_version) < Version("0.6.10"): from .. import __version__ - raise ImportError( + msg = ( f"Scanpy {__version__} needs anndata version >=0.6.10, " f"not {anndata_version}.\nRun `pip install anndata -U --no-deps`." ) + raise ImportError(msg) def getdoc(c_or_f: Callable | type) -> str | None: @@ -195,7 +197,8 @@ def _import_name(name: str) -> Any: try: obj = getattr(obj, name) except AttributeError: - raise RuntimeError(f"{parts[:i]}, {parts[i + 1:]}, {obj} {name}") + msg = f"{parts[:i]}, {parts[i + 1 :]}, {obj} {name}" + raise RuntimeError(msg) return obj @@ -255,9 +258,8 @@ def _check_array_function_arguments(**kwargs): # TODO: Figure out a better solution for documenting dispatched functions invalid_args = [k for k, v in kwargs.items() if v is not None] if len(invalid_args) > 0: - raise TypeError( - f"Arguments {invalid_args} are only valid if an AnnData object is passed." - ) + msg = f"Arguments {invalid_args} are only valid if an AnnData object is passed." + raise TypeError(msg) def _check_use_raw( @@ -350,16 +352,14 @@ def compute_association_matrix_of_groups( reference labels, entries are proportional to degree of association. """ if normalization not in {"prediction", "reference"}: - raise ValueError( - '`normalization` needs to be either "prediction" or "reference".' - ) + msg = '`normalization` needs to be either "prediction" or "reference".' + raise ValueError(msg) sanitize_anndata(adata) cats = adata.obs[reference].cat.categories for cat in cats: if cat in settings.categories_to_ignore: logg.info( - f"Ignoring category {cat!r} " - "as it’s in `settings.categories_to_ignore`." + f"Ignoring category {cat!r} as it’s in `settings.categories_to_ignore`." 
) asso_names: list[str] = [] asso_matrix: list[list[float]] = [] @@ -604,7 +604,8 @@ def broadcast_axis(divisor: Scaling_T, axis: Literal[0, 1]) -> Scaling_T: def check_op(op): if op not in {truediv, mul}: - raise ValueError(f"{op} not one of truediv or mul") + msg = f"{op} not one of truediv or mul" + raise ValueError(msg) @singledispatch @@ -639,9 +640,8 @@ def _( ) -> sparse.csr_matrix | sparse.csc_matrix: check_op(op) if out is not None and X.data is not out.data: - raise ValueError( - "`out` argument provided but not equal to X. This behavior is not supported for sparse matrix scaling." - ) + msg = "`out` argument provided but not equal to X. This behavior is not supported for sparse matrix scaling." + raise ValueError(msg) if not allow_divide_by_zero and op is truediv: scaling_array = scaling_array.copy() + (scaling_array == 0) @@ -697,9 +697,8 @@ def _( ) -> DaskArray: check_op(op) if out is not None: - raise TypeError( - "`out` is not `None`. Do not do in-place modifications on dask arrays." - ) + msg = "`out` is not `None`. Do not do in-place modifications on dask arrays." + raise TypeError(msg) import dask.array as da @@ -805,9 +804,8 @@ def sum_drop_keepdims(*args, **kwargs): axis = kwargs["axis"] if isinstance(axis, tuple): if len(axis) != 1: - raise ValueError( - f"`axis_sum` can only sum over one axis when `axis` arg is provided but got {axis} instead" - ) + msg = f"`axis_sum` can only sum over one axis when `axis` arg is provided but got {axis} instead" + raise ValueError(msg) kwargs["axis"] = axis[0] # returns a np.matrix normally, which is undesireable return np.array(np.sum(*args, dtype=dtype, **kwargs)) @@ -959,7 +957,8 @@ def subsample( Xsampled = np.array(X[rows]) else: if seed < 0: - raise ValueError(f"Invalid seed value < 0: {seed}") + msg = f"Invalid seed value < 0: {seed}" + raise ValueError(msg) n = int(X.shape[0] / subsample) np.random.seed(seed) Xsampled, rows = subsample_n(X, n=n) @@ -989,7 +988,8 @@ def subsample_n( Indices of rows that are stored in Xsampled. """ if n < 0: - raise ValueError("n must be greater 0") + msg = "n must be greater 0" + raise ValueError(msg) np.random.seed(seed) n = X.shape[0] if (n == 0 or n > X.shape[0]) else n rows = np.random.choice(X.shape[0], size=n, replace=False) @@ -1069,13 +1069,15 @@ def __init__(self, adata: AnnData, key=None): if key is None or key == "neighbors": if "neighbors" not in adata.uns: - raise KeyError('No "neighbors" in .uns') + msg = 'No "neighbors" in .uns' + raise KeyError(msg) self._neighbors_dict = adata.uns["neighbors"] self._conns_key = "connectivities" self._dists_key = "distances" else: if key not in adata.uns: - raise KeyError(f'No "{key}" in .uns') + msg = f"No {key!r} in .uns" + raise KeyError(msg) self._neighbors_dict = adata.uns[key] self._conns_key = self._neighbors_dict["connectivities_key"] self._dists_key = self._neighbors_dict["distances_key"] @@ -1108,11 +1110,13 @@ def __getitem__(self, key: Literal["connectivities_key"]) -> str: ... 
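The changes in this patch mechanically apply ruff's EM (flake8-errmsg) rules: build the message, bind it to a name, then raise, so a traceback shows `raise KeyError(msg)` instead of repeating a long string literal. A minimal before/after sketch (hypothetical function, not scanpy API):

    # before: flagged as EM102 (f-string passed directly to an exception)
    def pick(key: str, valid: set[str]) -> str:
        if key not in valid:
            raise KeyError(f"No {key!r}, expected one of {valid}")
        return key

    # after: message bound to a name first, then raised
    def pick(key: str, valid: set[str]) -> str:
        if key not in valid:
            msg = f"No {key!r}, expected one of {valid}"
            raise KeyError(msg)
        return key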
def __getitem__(self, key: str): if key == "distances": if "distances" not in self: - raise KeyError(f'No "{self._dists_key}" in .obsp') + msg = f"No {self._dists_key!r} in .obsp" + raise KeyError(msg) return self._distances elif key == "connectivities": if "connectivities" not in self: - raise KeyError(f'No "{self._conns_key}" in .obsp') + msg = f"No {self._conns_key!r} in .obsp" + raise KeyError(msg) return self._connectivities elif key == "connectivities_key": return self._conns_key @@ -1131,19 +1135,18 @@ def __contains__(self, key: str) -> bool: def _choose_graph(adata, obsp, neighbors_key): """Choose connectivities from neighbbors or another obsp column""" if obsp is not None and neighbors_key is not None: - raise ValueError( - "You can't specify both obsp, neighbors_key. " "Please select only one." - ) + msg = "You can't specify both obsp, neighbors_key. Please select only one." + raise ValueError(msg) if obsp is not None: return adata.obsp[obsp] else: neighbors = NeighborsView(adata, neighbors_key) if "connectivities" not in neighbors: - raise ValueError( - "You need to run `pp.neighbors` first " - "to compute a neighborhood graph." + msg = ( + "You need to run `pp.neighbors` first to compute a neighborhood graph." ) + raise ValueError(msg) return neighbors["connectivities"] @@ -1154,7 +1157,8 @@ def _resolve_axis( return (0, "obs") if axis in {1, "var"}: return (1, "var") - raise ValueError(f"`axis` must be either 0, 1, 'obs', or 'var', was {axis!r}") + msg = f"`axis` must be either 0, 1, 'obs', or 'var', was {axis!r}" + raise ValueError(msg) def is_backed_type(X: object) -> bool: @@ -1163,6 +1167,5 @@ def is_backed_type(X: object) -> bool: def raise_not_implemented_error_if_backed_type(X: object, method_name: str) -> None: if is_backed_type(X): - raise NotImplementedError( - f"{method_name} is not implemented for matrices of type {type(X)}" - ) + msg = f"{method_name} is not implemented for matrices of type {type(X)}" + raise NotImplementedError(msg) diff --git a/src/scanpy/_utils/_doctests.py b/src/scanpy/_utils/_doctests.py index 6a08099a24..0b3be18bbe 100644 --- a/src/scanpy/_utils/_doctests.py +++ b/src/scanpy/_utils/_doctests.py @@ -19,7 +19,8 @@ def decorator(func: F) -> F: def doctest_skip(reason: str) -> Callable[[F], F]: """Mark function so doctest is skipped.""" if not reason: - raise ValueError("reason must not be empty") + msg = "reason must not be empty" + raise ValueError(msg) def decorator(func: F) -> F: func._doctest_skip_reason = reason diff --git a/src/scanpy/_utils/compute/is_constant.py b/src/scanpy/_utils/compute/is_constant.py index 1bc147d68e..c9fac4abf0 100644 --- a/src/scanpy/_utils/compute/is_constant.py +++ b/src/scanpy/_utils/compute/is_constant.py @@ -24,9 +24,11 @@ def _check_axis_supported(wrapped: C) -> C: def func(a, axis=None): if axis is not None: if not isinstance(axis, Integral): - raise TypeError("axis must be integer or None.") + msg = "axis must be integer or None." 
+ raise TypeError(msg) if axis not in (0, 1): - raise NotImplementedError("We only support axis 0 and 1 at the moment") + msg = "We only support axis 0 and 1 at the moment" + raise NotImplementedError(msg) return wrapped(a, axis) return func diff --git a/src/scanpy/experimental/pp/_highly_variable_genes.py b/src/scanpy/experimental/pp/_highly_variable_genes.py index ab78f0a74a..7ad9f36bd7 100644 --- a/src/scanpy/experimental/pp/_highly_variable_genes.py +++ b/src/scanpy/experimental/pp/_highly_variable_genes.py @@ -159,7 +159,8 @@ def _highly_variable_pearson_residuals( if theta <= 0: # TODO: would "underdispersion" with negative theta make sense? # then only theta=0 were undefined.. - raise ValueError("Pearson residuals require theta > 0") + msg = "Pearson residuals require theta > 0" + raise ValueError(msg) # prepare clipping if batch_key is None: @@ -185,7 +186,8 @@ def _highly_variable_pearson_residuals( n = X_batch.shape[0] clip = np.sqrt(n) if clip < 0: - raise ValueError("Pearson residuals require `clip>=0` or `clip=None`.") + msg = "Pearson residuals require `clip>=0` or `clip=None`." + raise ValueError(msg) if sp_sparse.issparse(X_batch): X_batch = X_batch.tocsc() @@ -378,17 +380,19 @@ def highly_variable_genes( logg.info("extracting highly variable genes") if not isinstance(adata, AnnData): - raise ValueError( + msg = ( "`pp.highly_variable_genes` expects an `AnnData` argument, " "pass `inplace=False` if you want to return a `pd.DataFrame`." ) + raise ValueError(msg) if flavor == "pearson_residuals": if n_top_genes is None: - raise ValueError( + msg = ( "`pp.highly_variable_genes` requires the argument `n_top_genes`" " for `flavor='pearson_residuals'`" ) + raise ValueError(msg) return _highly_variable_pearson_residuals( adata, layer=layer, @@ -402,6 +406,5 @@ def highly_variable_genes( inplace=inplace, ) else: - raise ValueError( - "This is an experimental API and only `flavor=pearson_residuals` is available." - ) + msg = "This is an experimental API and only `flavor=pearson_residuals` is available." + raise ValueError(msg) diff --git a/src/scanpy/experimental/pp/_normalization.py b/src/scanpy/experimental/pp/_normalization.py index bc4dedbaf9..ef3d0311d7 100644 --- a/src/scanpy/experimental/pp/_normalization.py +++ b/src/scanpy/experimental/pp/_normalization.py @@ -42,13 +42,15 @@ def _pearson_residuals(X, theta, clip, check_values, *, copy: bool = False): if theta <= 0: # TODO: would "underdispersion" with negative theta make sense? # then only theta=0 were undefined.. - raise ValueError("Pearson residuals require theta > 0") + msg = "Pearson residuals require theta > 0" + raise ValueError(msg) # prepare clipping if clip is None: n = X.shape[0] clip = np.sqrt(n) if clip < 0: - raise ValueError("Pearson residuals require `clip>=0` or `clip=None`.") + msg = "Pearson residuals require `clip>=0` or `clip=None`." + raise ValueError(msg) if check_values and not check_nonnegative_integers(X): warn( @@ -128,7 +130,8 @@ def normalize_pearson_residuals( if copy: if not inplace: - raise ValueError("`copy=True` cannot be used with `inplace=False`.") + msg = "`copy=True` cannot be used with `inplace=False`." 
+ raise ValueError(msg) adata = adata.copy() view_to_actual(adata) diff --git a/src/scanpy/external/exporting.py b/src/scanpy/external/exporting.py index 9364b7d368..8379720ea6 100644 --- a/src/scanpy/external/exporting.py +++ b/src/scanpy/external/exporting.py @@ -86,7 +86,8 @@ def spring_project( neighbors_key = "neighbors" if neighbors_key not in adata.uns: - raise ValueError("Run `sc.pp.neighbors` first.") + msg = "Run `sc.pp.neighbors` first." + raise ValueError(msg) # check that requested 2-D embedding has been generated if embedding_method not in adata.obsm_keys(): @@ -101,9 +102,8 @@ def spring_project( + adata.uns[embedding_method]["params"]["layout"] ) else: - raise ValueError( - f"Run the specified embedding method `{embedding_method}` first." - ) + msg = f"Run the specified embedding method `{embedding_method}` first." + raise ValueError(msg) coords = adata.obsm[embedding_method] diff --git a/src/scanpy/external/pl.py b/src/scanpy/external/pl.py index a6ad48f718..ce305e2f06 100644 --- a/src/scanpy/external/pl.py +++ b/src/scanpy/external/pl.py @@ -198,9 +198,8 @@ def sam( try: dt = adata.obsm[projection] except KeyError: - raise ValueError( - "Please create a projection first using run_umap or run_tsne" - ) + msg = "Please create a projection first using run_umap or run_tsne" + raise ValueError(msg) else: dt = projection diff --git a/src/scanpy/external/pp/_bbknn.py b/src/scanpy/external/pp/_bbknn.py index 07d6e41f93..ee280cc824 100644 --- a/src/scanpy/external/pp/_bbknn.py +++ b/src/scanpy/external/pp/_bbknn.py @@ -133,7 +133,8 @@ def bbknn( try: from bbknn import bbknn except ImportError: - raise ImportError("Please install bbknn: `pip install bbknn`.") + msg = "Please install bbknn: `pip install bbknn`." + raise ImportError(msg) return bbknn( adata=adata, batch_key=batch_key, diff --git a/src/scanpy/external/pp/_dca.py b/src/scanpy/external/pp/_dca.py index c47fff90f2..20a97034b8 100644 --- a/src/scanpy/external/pp/_dca.py +++ b/src/scanpy/external/pp/_dca.py @@ -181,7 +181,8 @@ def dca( try: from dca.api import dca except ImportError: - raise ImportError("Please install dca package (>= 0.2.1) via `pip install dca`") + msg = "Please install dca package (>= 0.2.1) via `pip install dca`" + raise ImportError(msg) return dca( adata, diff --git a/src/scanpy/external/pp/_harmony_integrate.py b/src/scanpy/external/pp/_harmony_integrate.py index 1104690d53..824309f817 100644 --- a/src/scanpy/external/pp/_harmony_integrate.py +++ b/src/scanpy/external/pp/_harmony_integrate.py @@ -91,7 +91,8 @@ def harmony_integrate( try: import harmonypy except ImportError: - raise ImportError("\nplease install harmonypy:\n\n\tpip install harmonypy") + msg = "\nplease install harmonypy:\n\n\tpip install harmonypy" + raise ImportError(msg) X = adata.obsm[basis].astype(np.float64) diff --git a/src/scanpy/external/pp/_hashsolo.py b/src/scanpy/external/pp/_hashsolo.py index 256c863eee..dcb44239b1 100644 --- a/src/scanpy/external/pp/_hashsolo.py +++ b/src/scanpy/external/pp/_hashsolo.py @@ -352,15 +352,15 @@ def hashsolo( adata = adata.copy() if not inplace else adata data = adata.obs[cell_hashing_columns].values if not check_nonnegative_integers(data): - raise ValueError("Cell hashing counts must be non-negative") + msg = "Cell hashing counts must be non-negative" + raise ValueError(msg) if (number_of_noise_barcodes is not None) and ( number_of_noise_barcodes >= len(cell_hashing_columns) ): - raise ValueError( - "number_of_noise_barcodes must be at least one less \ + msg = "number_of_noise_barcodes must 
be at least one less \ than the number of samples you have as determined by the number of \ cell_hashing_columns you've given as input " - ) + raise ValueError(msg) num_of_cells = adata.shape[0] results = pd.DataFrame( np.zeros((num_of_cells, 6)), diff --git a/src/scanpy/external/pp/_magic.py b/src/scanpy/external/pp/_magic.py index 132d2a6448..12e93f1a8e 100644 --- a/src/scanpy/external/pp/_magic.py +++ b/src/scanpy/external/pp/_magic.py @@ -142,34 +142,38 @@ def magic( try: from magic import MAGIC, __version__ except ImportError: - raise ImportError( + msg = ( "Please install magic package via `pip install --user " "git+git://github.com/KrishnaswamyLab/MAGIC.git#subdirectory=python`" ) + raise ImportError(msg) else: if Version(__version__) < Version(MIN_VERSION): - raise ImportError( + msg = ( "scanpy requires magic-impute >= " f"v{MIN_VERSION} (detected: v{__version__}). " "Please update magic package via `pip install --user " "--upgrade magic-impute`" ) + raise ImportError(msg) start = logg.info("computing MAGIC") all_or_pca = isinstance(name_list, str | NoneType) if all_or_pca and name_list not in {"all_genes", "pca_only", None}: - raise ValueError( + msg = ( "Invalid string value for `name_list`: " "Only `'all_genes'` and `'pca_only'` are allowed." ) + raise ValueError(msg) if copy is None: copy = not all_or_pca elif not all_or_pca and not copy: - raise ValueError( + msg = ( "Can only perform MAGIC in-place with `name_list=='all_genes' or " f"`name_list=='pca_only'` (got {name_list}). Consider setting " "`copy=True`" ) + raise ValueError(msg) adata = adata.copy() if copy else adata n_jobs = settings.n_jobs if n_jobs is None else n_jobs diff --git a/src/scanpy/external/pp/_mnn_correct.py b/src/scanpy/external/pp/_mnn_correct.py index a497189913..518686dc75 100644 --- a/src/scanpy/external/pp/_mnn_correct.py +++ b/src/scanpy/external/pp/_mnn_correct.py @@ -133,10 +133,8 @@ def mnn_correct( import mnnpy from mnnpy import mnn_correct except ImportError: - raise ImportError( - "Please install the package mnnpy " - "(https://github.com/chriscainx/mnnpy). " - ) + msg = "Please install the package mnnpy (https://github.com/chriscainx/mnnpy). " + raise ImportError(msg) n_jobs = settings.n_jobs if n_jobs is None else n_jobs diff --git a/src/scanpy/external/pp/_scanorama_integrate.py b/src/scanpy/external/pp/_scanorama_integrate.py index ca847f8351..c5fb2683b4 100644 --- a/src/scanpy/external/pp/_scanorama_integrate.py +++ b/src/scanpy/external/pp/_scanorama_integrate.py @@ -111,7 +111,8 @@ def scanorama_integrate( try: import scanorama except ImportError: - raise ImportError("\nplease install Scanorama:\n\n\tpip install scanorama") + msg = "\nplease install Scanorama:\n\n\tpip install scanorama" + raise ImportError(msg) # Get batch indices in linear time. curr_batch = None @@ -123,7 +124,8 @@ def scanorama_integrate( curr_batch = batch_name if batch_name in batch_names: # Contiguous batches important for preserving cell order. - raise ValueError("Detected non-contiguous batches.") + msg = "Detected non-contiguous batches." + raise ValueError(msg) batch_names.append(batch_name) # Preserve name order. 
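The check above rejects non-contiguous batches because the linear-time scan assumes every batch value forms exactly one run. The same invariant as a standalone sketch (not scanorama or scanpy API):

    import numpy as np
    import pandas as pd

    def batches_are_contiguous(batches) -> bool:
        # one run per batch value <=> run count equals distinct-value count
        codes = pd.Categorical(batches).codes
        n_runs = 1 + int(np.count_nonzero(np.diff(codes)))
        return n_runs == len(np.unique(codes))

    assert batches_are_contiguous(["a", "a", "b", "b"])
    assert not batches_are_contiguous(["a", "b", "a"])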
name2idx[batch_name] = [] name2idx[batch_name].append(idx) diff --git a/src/scanpy/external/tl/_harmony_timeseries.py b/src/scanpy/external/tl/_harmony_timeseries.py index d1746af45a..de3f8cde26 100644 --- a/src/scanpy/external/tl/_harmony_timeseries.py +++ b/src/scanpy/external/tl/_harmony_timeseries.py @@ -140,13 +140,15 @@ def harmony_timeseries( try: import harmony except ImportError: - raise ImportError("\nplease install harmony:\n\n\tpip install harmonyTS") + msg = "\nplease install harmony:\n\n\tpip install harmonyTS" + raise ImportError(msg) adata = adata.copy() if copy else adata logg.info("Harmony augmented affinity matrix") if adata.obs[tp].dtype.name != "category": - raise ValueError(f"{tp!r} column does not contain Categorical data") + msg = f"{tp!r} column does not contain Categorical data" + raise ValueError(msg) timepoints = adata.obs[tp].cat.categories.tolist() timepoint_connections = pd.DataFrame(np.array([timepoints[:-1], timepoints[1:]]).T) diff --git a/src/scanpy/external/tl/_palantir.py b/src/scanpy/external/tl/_palantir.py index 854301466a..eb060bbbe0 100644 --- a/src/scanpy/external/tl/_palantir.py +++ b/src/scanpy/external/tl/_palantir.py @@ -340,4 +340,5 @@ def _check_import(): try: import palantir # noqa: F401 except ImportError: - raise ImportError("\nplease install palantir:\n\tpip install palantir") + msg = "\nplease install palantir:\n\tpip install palantir" + raise ImportError(msg) diff --git a/src/scanpy/external/tl/_phate.py b/src/scanpy/external/tl/_phate.py index ff50a1e6f7..91d8191e60 100644 --- a/src/scanpy/external/tl/_phate.py +++ b/src/scanpy/external/tl/_phate.py @@ -154,10 +154,11 @@ def phate( try: import phate except ImportError: - raise ImportError( + msg = ( "You need to install the package `phate`: please run `pip install " "--user phate` in a terminal." ) + raise ImportError(msg) X_phate = phate.PHATE( n_components=n_components, k=k, @@ -179,6 +180,6 @@ def phate( logg.info( " finished", time=start, - deep=("added\n" " 'X_phate', PHATE coordinates (adata.obsm)"), + deep=("added\n 'X_phate', PHATE coordinates (adata.obsm)"), ) return adata if copy else None diff --git a/src/scanpy/external/tl/_phenograph.py b/src/scanpy/external/tl/_phenograph.py index 24e10bcb85..fdc3973771 100644 --- a/src/scanpy/external/tl/_phenograph.py +++ b/src/scanpy/external/tl/_phenograph.py @@ -226,17 +226,19 @@ def phenograph( assert phenograph.__version__ >= "1.5.3" except (ImportError, AssertionError, AttributeError): - raise ImportError( + msg = ( "please install the latest release of phenograph:\n\t" "pip install -U PhenoGraph" ) + raise ImportError(msg) if isinstance(data, AnnData): adata = data try: data = data.obsm["X_pca"] except KeyError: - raise KeyError("Please run `sc.pp.pca` on `data` and try again!") + msg = "Please run `sc.pp.pca` on `data` and try again!" + raise KeyError(msg) else: adata = None copy = True diff --git a/src/scanpy/external/tl/_pypairs.py b/src/scanpy/external/tl/_pypairs.py index 255334fe7a..2db98ff9a7 100644 --- a/src/scanpy/external/tl/_pypairs.py +++ b/src/scanpy/external/tl/_pypairs.py @@ -153,8 +153,10 @@ def _check_import(): try: import pypairs except ImportError: - raise ImportError("You need to install the package `pypairs`.") + msg = "You need to install the package `pypairs`." 
+ raise ImportError(msg) min_version = Version("3.0.9") if Version(pypairs.__version__) < min_version: - raise ImportError(f"Please only use `pypairs` >= {min_version}") + msg = f"Please only use `pypairs` >= {min_version}" + raise ImportError(msg) diff --git a/src/scanpy/external/tl/_sam.py b/src/scanpy/external/tl/_sam.py index ebf3156b9a..8daa2c0091 100644 --- a/src/scanpy/external/tl/_sam.py +++ b/src/scanpy/external/tl/_sam.py @@ -211,12 +211,13 @@ def sam( try: from samalg import SAM except ImportError: - raise ImportError( + msg = ( "\nplease install sam-algorithm: \n\n" "\tgit clone git://github.com/atarashansky/self-assembling-manifold.git\n" "\tcd self-assembling-manifold\n" "\tpip install ." ) + raise ImportError(msg) logg.info("Self-assembling manifold") diff --git a/src/scanpy/external/tl/_trimap.py b/src/scanpy/external/tl/_trimap.py index 9146e79b84..122a4792b7 100644 --- a/src/scanpy/external/tl/_trimap.py +++ b/src/scanpy/external/tl/_trimap.py @@ -108,7 +108,8 @@ def trimap( try: from trimap import TRIMAP except ImportError: - raise ImportError("\nplease install trimap: \n\n\tsudo pip install trimap") + msg = "\nplease install trimap: \n\n\tsudo pip install trimap" + raise ImportError(msg) adata = adata.copy() if copy else adata start = logg.info("computing TriMap") adata = adata.copy() if copy else adata @@ -121,10 +122,11 @@ def trimap( else: X = adata.X if scp.issparse(X): - raise ValueError( + msg = ( "trimap currently does not support sparse matrices. Please" "use a dense matrix or apply pca first." ) + raise ValueError(msg) logg.warning("`X_pca` not found. Run `sc.pp.pca` first for speedup.") X_trimap = TRIMAP( n_dims=n_components, diff --git a/src/scanpy/external/tl/_wishbone.py b/src/scanpy/external/tl/_wishbone.py index e857226feb..3b85ae14a1 100644 --- a/src/scanpy/external/tl/_wishbone.py +++ b/src/scanpy/external/tl/_wishbone.py @@ -104,17 +104,17 @@ def wishbone( try: from wishbone.core import wishbone as c_wishbone except ImportError: - raise ImportError( - "\nplease install wishbone:\n\n\thttps://github.com/dpeerlab/wishbone" - ) + msg = "\nplease install wishbone:\n\n\thttps://github.com/dpeerlab/wishbone" + raise ImportError(msg) # Start cell index s = np.where(adata.obs_names == start_cell)[0] if len(s) == 0: - raise RuntimeError( + msg = ( f"Start cell {start_cell} not found in data. " "Please rerun with correct start cell." ) + raise RuntimeError(msg) if isinstance(num_waypoints, Collection): diff = np.setdiff1d(num_waypoints, adata.obs.index) if diff.size > 0: @@ -124,10 +124,11 @@ def wishbone( ) num_waypoints = diff.tolist() elif num_waypoints > adata.shape[0]: - raise RuntimeError( + msg = ( "num_waypoints parameter is higher than the number of cells in the " "dataset. Please select a smaller number" ) + raise RuntimeError(msg) s = s[0] # Run the algorithm diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 53a18bb47c..94bf202b69 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -256,25 +256,29 @@ def aggregate( Note that this filters out any combination of groups that wasn't present in the original data. """ if not isinstance(adata, AnnData): - raise NotImplementedError( + msg = ( "sc.get.aggregate is currently only implemented for AnnData input, " f"was passed {type(adata)}." 
) + raise NotImplementedError(msg) if axis is None: axis = 1 if varm else 0 axis, axis_name = _resolve_axis(axis) mask = _check_mask(adata, mask, axis_name) data = adata.X if sum(p is not None for p in [varm, obsm, layer]) > 1: - raise TypeError("Please only provide one (or none) of varm, obsm, or layer") + msg = "Please only provide one (or none) of varm, obsm, or layer" + raise TypeError(msg) if varm is not None: if axis != 1: - raise ValueError("varm can only be used when axis is 1") + msg = "varm can only be used when axis is 1" + raise ValueError(msg) data = adata.varm[varm] elif obsm is not None: if axis != 0: - raise ValueError("obsm can only be used when axis is 0") + msg = "obsm can only be used when axis is 0" + raise ValueError(msg) data = adata.obsm[obsm] elif layer is not None: data = adata.layers[layer] @@ -324,7 +328,8 @@ def _aggregate( mask: NDArray[np.bool_] | None = None, dof: int = 1, ): - raise NotImplementedError(f"Data type {type(data)} not supported for aggregation") + msg = f"Data type {type(data)} not supported for aggregation" + raise NotImplementedError(msg) @_aggregate.register(pd.DataFrame) @@ -347,7 +352,8 @@ def aggregate_array( funcs = set([func] if isinstance(func, str) else func) if unknown := funcs - get_literal_vals(AggType): - raise ValueError(f"func {unknown} is not one of {get_literal_vals(AggType)}") + msg = f"func {unknown} is not one of {get_literal_vals(AggType)}" + raise ValueError(msg) if "sum" in funcs: # sum is calculated separately from the rest agg = groupby.sum() diff --git a/src/scanpy/get/get.py b/src/scanpy/get/get.py index c36ddde8f8..abfa51d1f9 100644 --- a/src/scanpy/get/get.py +++ b/src/scanpy/get/get.py @@ -149,18 +149,20 @@ def _check_indices( # be further duplicated when selecting them. if not dim_df.columns.is_unique: dup_cols = dim_df.columns[dim_df.columns.duplicated()].tolist() - raise ValueError( + msg = ( f"adata.{dim} contains duplicated columns. Please rename or remove " "these columns first.\n`" f"Duplicated columns {dup_cols}" ) + raise ValueError(msg) if not alt_index.is_unique: - raise ValueError( + msg = ( f"{alt_repr}.{alt_dim}_names contains duplicated items\n" f"Please rename these {alt_dim} names first for example using " f"`adata.{alt_dim}_names_make_unique()`" ) + raise ValueError(msg) # use only unique keys, otherwise duplicated keys will # further duplicate when reordering the keys later in the function @@ -168,27 +170,26 @@ def _check_indices( if key in dim_df.columns: col_keys.append(key) if key in alt_names.index: - raise KeyError( - f"The key '{key}' is found in both adata.{dim} and {alt_repr}.{alt_search_repr}." - ) + msg = f"The key {key!r} is found in both adata.{dim} and {alt_repr}.{alt_search_repr}." + raise KeyError(msg) elif key in alt_names.index: val = alt_names[key] if isinstance(val, pd.Series): # while var_names must be unique, adata.var[gene_symbols] does not # It's still ambiguous to refer to a duplicated entry though. assert alias_index is not None - raise KeyError( - f"Found duplicate entries for '{key}' in {alt_repr}.{alt_search_repr}." - ) + msg = f"Found duplicate entries for {key!r} in {alt_repr}.{alt_search_repr}." + raise KeyError(msg) index_keys.append(val) index_aliases.append(key) else: not_found.append(key) if len(not_found) > 0: - raise KeyError( - f"Could not find keys '{not_found}' in columns of `adata.{dim}` or in" + msg = ( + f"Could not find keys {not_found!r} in columns of `adata.{dim}` or in" f" {alt_repr}.{alt_search_repr}." 
) + raise KeyError(msg) return col_keys, index_keys, index_aliases @@ -286,9 +287,9 @@ def obs_df( if isinstance(keys, str): keys = [keys] if use_raw: - assert ( - layer is None - ), "Cannot specify use_raw=True and a layer at the same time." + assert layer is None, ( + "Cannot specify use_raw=True and a layer at the same time." + ) var = adata.raw.var else: var = adata.var @@ -430,7 +431,8 @@ def _get_obs_rep( """ # https://github.com/scverse/scanpy/issues/1546 if not isinstance(use_raw, bool): - raise TypeError(f"use_raw expected to be bool, was {type(use_raw)}.") + msg = f"use_raw expected to be bool, was {type(use_raw)}." + raise TypeError(msg) is_layer = layer is not None is_raw = use_raw is not False @@ -448,10 +450,11 @@ def _get_obs_rep( return adata.obsm[obsm] if is_obsp: return adata.obsp[obsp] - raise AssertionError( + msg = ( "That was unexpected. Please report this bug at:\n\n\t" "https://github.com/scverse/scanpy/issues" ) + raise AssertionError(msg) def _set_obs_rep( diff --git a/src/scanpy/logging.py b/src/scanpy/logging.py index 3aa0ca494c..7bd678f568 100644 --- a/src/scanpy/logging.py +++ b/src/scanpy/logging.py @@ -181,7 +181,7 @@ def print_version_and_date(*, file=None): if file is None: file = sys.stdout print( - f"Running Scanpy {__version__}, " f"on {datetime.now():%Y-%m-%d %H:%M}.", + f"Running Scanpy {__version__}, on {datetime.now():%Y-%m-%d %H:%M}.", file=file, ) diff --git a/src/scanpy/metrics/_gearys_c.py b/src/scanpy/metrics/_gearys_c.py index 358a201eed..cf4220eb7a 100644 --- a/src/scanpy/metrics/_gearys_c.py +++ b/src/scanpy/metrics/_gearys_c.py @@ -113,7 +113,8 @@ def gearys_c( elif "neighbors" in adata.uns: g = adata.uns["neighbors"]["connectivities"] else: - raise ValueError("Must run neighbors first.") + msg = "Must run neighbors first." + raise ValueError(msg) else: raise NotImplementedError() if vals is None: diff --git a/src/scanpy/metrics/_morans_i.py b/src/scanpy/metrics/_morans_i.py index 5e4ab50788..c21c455f38 100644 --- a/src/scanpy/metrics/_morans_i.py +++ b/src/scanpy/metrics/_morans_i.py @@ -112,7 +112,8 @@ def morans_i( elif "neighbors" in adata.uns: g = adata.uns["neighbors"]["connectivities"] else: - raise ValueError("Must run neighbors first.") + msg = "Must run neighbors first." + raise ValueError(msg) else: raise NotImplementedError() if vals is None: diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index ec5957b325..214043727b 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -425,10 +425,11 @@ def count_nonzero(a: np.ndarray | csr_matrix) -> int: self._eigen_basis = _backwards_compat_get_full_X_diffmap(adata) if n_dcs is not None: if n_dcs > len(self._eigen_values): - raise ValueError( + msg = ( f"Cannot instantiate using `n_dcs`={n_dcs}. " "Compute diffmap/spectrum with more components first." ) + raise ValueError(msg) self._eigen_values = self._eigen_values[:n_dcs] self._eigen_basis = self._eigen_basis[:, :n_dcs] self.n_dcs = len(self._eigen_values) @@ -789,7 +790,8 @@ def compute_eigen( """ np.set_printoptions(precision=10) if self._transitions_sym is None: - raise ValueError("Run `.compute_transitions` first.") + msg = "Run `.compute_transitions` first." 
+ raise ValueError(msg) matrix = self._transitions_sym # compute the spectrum if n_comps == 0: @@ -812,9 +814,7 @@ def compute_eigen( if sort == "decrease": evals = evals[::-1] evecs = evecs[:, ::-1] - logg.info( - f" eigenvalues of transition matrix\n" f"{indent(str(evals), ' ')}" - ) + logg.info(f" eigenvalues of transition matrix\n{indent(str(evals), ' ')}") if self._number_connected_components > len(evals) / 2: logg.warning("Transition matrix has many disconnected components!") self._eigen_values = evals @@ -825,10 +825,11 @@ def _init_iroot(self): # set iroot directly if "iroot" in self._adata.uns: if self._adata.uns["iroot"] >= self._adata.n_obs: - logg.warning( - f'Root cell index {self._adata.uns["iroot"]} does not ' + msg = ( + f"Root cell index {self._adata.uns['iroot']} does not " f"exist for {self._adata.n_obs} samples. It’s ignored." ) + logg.warning(msg) else: self.iroot = self._adata.uns["iroot"] return @@ -890,9 +891,8 @@ def _set_iroot_via_xroot(self, xroot: np.ndarray): condition, only relevant for computing pseudotime. """ if self._adata.shape[1] != xroot.size: - raise ValueError( - "The root vector you provided does not have the " "correct dimension." - ) + msg = "The root vector you provided does not have the correct dimension." + raise ValueError(msg) # this is the squared distance dsqroot = 1e10 iroot = 0 diff --git a/src/scanpy/plotting/_anndata.py b/src/scanpy/plotting/_anndata.py index a93d55699b..75dd210c0b 100755 --- a/src/scanpy/plotting/_anndata.py +++ b/src/scanpy/plotting/_anndata.py @@ -163,7 +163,8 @@ def scatter( if basis is not None: return _scatter_obs(**args) if x is None or y is None: - raise ValueError("Either provide a `basis` or `x` and `y`.") + msg = "Either provide a `basis` or `x` and `y`." + raise ValueError(msg) if _check_if_annotations(adata, "obs", x=x, y=y, colors=color, use_raw=use_raw): return _scatter_obs(**args) if _check_if_annotations(adata, "var", x=x, y=y, colors=color, use_raw=use_raw): @@ -172,10 +173,11 @@ def scatter( # store .uns annotations that were added to the new adata object adata.uns = args_t["adata"].uns return axs - raise ValueError( + msg = ( "`x`, `y`, and potential `color` inputs must all " "come from either `.obs` or `.var`" ) + raise ValueError(msg) def _check_if_annotations( @@ -259,22 +261,23 @@ def _scatter_obs( layers = tuple(layers) for layer in layers: if layer not in adata.layers and layer not in ["X", None]: - raise ValueError( + msg = ( "`layers` should have elements that are " "either None or in adata.layers.keys()." ) + raise ValueError(msg) else: - raise ValueError( + msg = ( "`layers` should be a string or a collection of strings " f"with length 3, had value '{layers}'" ) + raise ValueError(msg) if use_raw and layers not in [("X", "X", "X"), (None, None, None)]: ValueError("`use_raw` must be `False` if layers are used.") if legend_loc not in (valid_legend_locs := get_literal_vals(_utils._LegendLoc)): - raise ValueError( - f"Invalid `legend_loc`, need to be one of: {valid_legend_locs}." - ) + msg = f"Invalid `legend_loc`, need to be one of: {valid_legend_locs}." 
+        raise ValueError(msg)
     if components is None:
         components = "1,2" if "2d" in projection else "1,2,3"
     if isinstance(components, str):
@@ -294,9 +297,8 @@ def _scatter_obs(
             if basis == "diffmap":
                 components -= 1
         except KeyError:
-            raise KeyError(
-                f"compute coordinates using visualization tool {basis} first"
-            )
+            msg = f"compute coordinates using visualization tool {basis} first"
+            raise KeyError(msg)
     elif x is not None and y is not None:
         if use_raw:
             if x in adata.obs.columns:
@@ -313,7 +315,8 @@ def _scatter_obs(
 
             Y = np.c_[x_arr, y_arr]
     else:
-        raise ValueError("Either provide a `basis` or `x` and `y`.")
+        msg = "Either provide a `basis` or `x` and `y`."
+        raise ValueError(msg)
 
     if size is None:
         n = Y.shape[0]
@@ -375,10 +378,11 @@ def _scatter_obs(
             c = key
             colorbar = False
         else:
-            raise ValueError(
+            msg = (
                 f"key {key!r} is invalid! pass valid observation annotation, "
                 f"one of {adata.obs_keys()} or a gene name {adata.var_names}"
             )
+            raise ValueError(msg)
         if colorbar is None:
             colorbar = not categorical
         colorbars.append(colorbar)
@@ -451,10 +455,11 @@ def add_centroid(centroids, name, Y, mask):
         groups = [groups] if isinstance(groups, str) else groups
         for name in groups:
             if name not in set(adata.obs[key].cat.categories):
-                raise ValueError(
+                msg = (
                     f"{name!r} is invalid! specify valid name, "
                     f"one of {adata.obs[key].cat.categories}"
                 )
+                raise ValueError(msg)
             else:
                 iname = np.flatnonzero(
                     adata.obs[key].cat.categories.values == name
@@ -844,23 +849,21 @@ def violin(
         ylabel = [ylabel] * (1 if groupby is None else len(keys))
     if groupby is None:
         if len(ylabel) != 1:
-            raise ValueError(
-                f"Expected number of y-labels to be `1`, found `{len(ylabel)}`."
-            )
+            msg = f"Expected number of y-labels to be `1`, found `{len(ylabel)}`."
+            raise ValueError(msg)
     elif len(ylabel) != len(keys):
-        raise ValueError(
-            f"Expected number of y-labels to be `{len(keys)}`, "
-            f"found `{len(ylabel)}`."
-        )
+        msg = f"Expected number of y-labels to be `{len(keys)}`, found `{len(ylabel)}`."
+        raise ValueError(msg)
 
     if groupby is not None:
         obs_df = get.obs_df(adata, keys=[groupby] + keys, layer=layer, use_raw=use_raw)
         if kwds.get("palette") is None:
             if not isinstance(adata.obs[groupby].dtype, CategoricalDtype):
-                raise ValueError(
+                msg = (
                     f"The column `adata.obs[{groupby!r}]` needs to be categorical, "
                     f"but is of dtype {adata.obs[groupby].dtype}."
                 )
+                raise ValueError(msg)
             _utils.add_colors_for_categorical_sample_annotation(adata, groupby)
             kwds["hue"] = groupby
             kwds["palette"] = dict(
@@ -1022,7 +1025,8 @@ def clustermap(
     import seaborn as sns  # Slow import, only import if called
 
     if not isinstance(obs_keys, str | NoneType):
-        raise ValueError("Currently, only a single key is supported.")
+        msg = "Currently, only a single key is supported."
+        raise ValueError(msg)
     sanitize_anndata(adata)
     use_raw = _check_use_raw(adata, use_raw)
     X = adata.raw.X if use_raw else adata.X
@@ -1555,11 +1559,12 @@ def tracksplot(
 
     """
    if groupby not in adata.obs_keys() or adata.obs[groupby].dtype.name != "category":
-        raise ValueError(
+        msg = (
             "groupby has to be a valid categorical observation. "
             f"Given value: {groupby}, valid categorical observations: "
-            f'{[x for x in adata.obs_keys() if adata.obs[x].dtype.name == "category"]}'
+            f"{[x for x in adata.obs_keys() if adata.obs[x].dtype.name == 'category']}"
         )
+        raise ValueError(msg)
 
     var_names, var_group_labels, var_group_positions = _check_var_names_type(
         var_names, var_group_labels, var_group_positions
@@ -1891,7 +1896,8 @@ def correlation_matrix(
     dendrogram = ax is None
     if dendrogram:
         if ax is not None:
-            raise ValueError("Can only plot dendrogram when not plotting to an axis")
+            msg = "Can only plot dendrogram when not plotting to an axis"
+            raise ValueError(msg)
         assert (len(index)) == corr_matrix.shape[0]
         corr_matrix = corr_matrix[index, :]
         corr_matrix = corr_matrix[:, index]
@@ -2059,10 +2065,11 @@ def _prepare_dataframe(
                 f"Given {group}, is not in observations: {adata.obs_keys()}" + msg
             )
         if group in adata.obs.columns and group == adata.obs.index.name:
-            raise ValueError(
+            msg = (
                 f"Given group {group} is both and index and a column level, "
                 "which is ambiguous."
             )
+            raise ValueError(msg)
         if group == adata.obs.index.name:
             groupby_index = group
     if groupby_index is not None:
@@ -2277,19 +2284,12 @@ def _reorder_categories_after_dendrogram(
         'var_group_labels', and 'var_group_positions'
 
     """
-    dendrogram_key = _get_dendrogram_key(adata, dendrogram_key, groupby)
-
     if isinstance(groupby, str):
         groupby = [groupby]
 
-    dendro_info = adata.uns[dendrogram_key]
-    if groupby != dendro_info["groupby"]:
-        raise ValueError(
-            "Incompatible observations. The precomputed dendrogram contains "
-            f"information for the observation: '{groupby}' while the plot is "
-            f"made for the observation: '{dendro_info['groupby']}. "
-            "Please run `sc.tl.dendrogram` using the right observation.'"
-        )
+    dendro_info = adata.uns[
+        _get_dendrogram_key(adata, dendrogram_key, groupby, validate_groupby=True)
+    ]
 
     if categories is None:
         categories = adata.obs[dendro_info["groupby"]].cat.categories
@@ -2299,7 +2299,7 @@ def _reorder_categories_after_dendrogram(
     categories_ordered = dendro_info["categories_ordered"]
 
     if len(categories) != len(categories_idx_ordered):
-        raise ValueError(
+        msg = (
             "Incompatible observations. Dendrogram data has "
             f"{len(categories_idx_ordered)} categories but current groupby "
             f"observation {groupby!r} contains {len(categories)} categories. "
@@ -2307,6 +2307,7 @@ def _reorder_categories_after_dendrogram(
             "initial computation of `sc.tl.dendrogram`. "
             "Please run `sc.tl.dendrogram` again.'"
         )
+        raise ValueError(msg)
 
     # reorder var_groups (if any)
     if var_group_positions is None or var_group_labels is None:
@@ -2362,7 +2363,11 @@ def _format_first_three_categories(categories):
 
 
 def _get_dendrogram_key(
-    adata: AnnData, dendrogram_key: str | None, groupby: str | Sequence[str]
+    adata: AnnData,
+    dendrogram_key: str | None,
+    groupby: str | Sequence[str],
+    *,
+    validate_groupby: bool = False,
 ) -> str:
     # the `dendrogram_key` can be a bool an NoneType or the name of the
     # dendrogram key. By default the name of the dendrogram key is 'dendrogram'
     if isinstance(groupby, str):
         dendrogram_key = f"dendrogram_{groupby}"
     elif isinstance(groupby, Sequence):
-        dendrogram_key = f'dendrogram_{"_".join(groupby)}'
+        dendrogram_key = f"dendrogram_{'_'.join(groupby)}"
     else:
         msg = f"groupby has wrong type: {type(groupby).__name__}."
         raise AssertionError(msg)
@@ -2386,10 +2391,22 @@ def _get_dendrogram_key(
             dendrogram(adata, groupby, key_added=dendrogram_key)
 
     if "dendrogram_info" not in adata.uns[dendrogram_key]:
-        raise ValueError(
+        msg = (
             f"The given dendrogram key ({dendrogram_key!r}) does not contain "
             "valid dendrogram information."
         )
+        raise ValueError(msg)
+
+    if validate_groupby:
+        existing_groupby = adata.uns[dendrogram_key]["groupby"]
+        if groupby != existing_groupby:
+            msg = (
+                "Incompatible observations. The precomputed dendrogram contains "
+                f"information for the observation: {groupby!r} while the plot is "
+                f"made for the observation: {existing_groupby!r}. "
+                "Please run `sc.tl.dendrogram` using the right observation.'"
+            )
+            raise ValueError(msg)
 
     return dendrogram_key
 
 
diff --git a/src/scanpy/plotting/_baseplot_class.py b/src/scanpy/plotting/_baseplot_class.py
index fff1b40322..e14d387f84 100644
--- a/src/scanpy/plotting/_baseplot_class.py
+++ b/src/scanpy/plotting/_baseplot_class.py
@@ -899,23 +899,18 @@ def _format_first_three_categories(_categories):
         _categories = _categories[:3] + ["etc."]
         return ", ".join(_categories)
 
-        key = _get_dendrogram_key(self.adata, dendrogram_key, self.groupby)
-
-        dendro_info = self.adata.uns[key]
-        if self.groupby != dendro_info["groupby"]:
-            raise ValueError(
-                "Incompatible observations. The precomputed dendrogram contains "
-                f"information for the observation: '{self.groupby}' while the plot is "
-                f"made for the observation: '{dendro_info['groupby']}. "
-                "Please run `sc.tl.dendrogram` using the right observation.'"
+        dendro_info = self.adata.uns[
+            _get_dendrogram_key(
+                self.adata, dendrogram_key, self.groupby, validate_groupby=True
             )
+        ]
 
         # order of groupby categories
         categories_idx_ordered = dendro_info["categories_idx_ordered"]
         categories_ordered = dendro_info["categories_ordered"]
 
         if len(self.categories) != len(categories_idx_ordered):
-            raise ValueError(
+            msg = (
                 "Incompatible observations. Dendrogram data has "
                 f"{len(categories_idx_ordered)} categories but current groupby "
                 f"observation {self.groupby!r} contains {len(self.categories)} categories. "
@@ -923,6 +918,7 @@ def _format_first_three_categories(_categories):
                 "initial computation of `sc.tl.dendrogram`. "
                 "Please run `sc.tl.dendrogram` again.'"
             )
+            raise ValueError(msg)
 
         # reorder var_groups (if any)
         if self.var_names is not None:
diff --git a/src/scanpy/plotting/_dotplot.py b/src/scanpy/plotting/_dotplot.py
index e2ae434db6..da3d16379b 100644
--- a/src/scanpy/plotting/_dotplot.py
+++ b/src/scanpy/plotting/_dotplot.py
@@ -681,11 +681,11 @@ def _dotplot(
 
     """
     assert dot_size.shape == dot_color.shape, (
-        "please check that dot_size " "and dot_color dataframes have the same shape"
+        "please check that dot_size and dot_color dataframes have the same shape"
    )
 
     assert list(dot_size.index) == list(dot_color.index), (
-        "please check that dot_size " "and dot_color dataframes have the same index"
+        "please check that dot_size and dot_color dataframes have the same index"
    )
 
     assert list(dot_size.columns) == list(dot_color.columns), (
@@ -721,12 +721,14 @@ def _dotplot(
         dot_max = np.ceil(max(frac) * 10) / 10
     else:
         if dot_max < 0 or dot_max > 1:
-            raise ValueError("`dot_max` value has to be between 0 and 1")
+            msg = "`dot_max` value has to be between 0 and 1"
+            raise ValueError(msg)
     if dot_min is None:
         dot_min = 0
     else:
         if dot_min < 0 or dot_min > 1:
-            raise ValueError("`dot_min` value has to be between 0 and 1")
+            msg = "`dot_min` value has to be between 0 and 1"
+            raise ValueError(msg)
 
     if dot_min != 0 or dot_max != 1:
         # clip frac between dot_min and dot_max
diff --git a/src/scanpy/plotting/_scrublet.py b/src/scanpy/plotting/_scrublet.py
index 4a1247574d..050aec6f53 100644
--- a/src/scanpy/plotting/_scrublet.py
+++ b/src/scanpy/plotting/_scrublet.py
@@ -72,9 +72,8 @@ def scrublet_score_distribution(
 
     """
     if "scrublet" not in adata.uns:
-        raise ValueError(
-            "Please run scrublet before trying to generate the scrublet plot."
-        )
+        msg = "Please run scrublet before trying to generate the scrublet plot."
+        raise ValueError(msg)
 
     # If batched_by is populated, then we know Scrublet was run over multiple batches
 
diff --git a/src/scanpy/plotting/_stacked_violin.py b/src/scanpy/plotting/_stacked_violin.py
index e47680facc..3c58ead35f 100644
--- a/src/scanpy/plotting/_stacked_violin.py
+++ b/src/scanpy/plotting/_stacked_violin.py
@@ -750,7 +750,7 @@ def stacked_violin(
         e.g. `'red'` or `'#cc33ff'`.
     {show_save_ax}
     {vminmax}
-    kwds
+    **kwds
         Are passed to :func:`~seaborn.violinplot`.
 
     Returns
diff --git a/src/scanpy/plotting/_tools/__init__.py b/src/scanpy/plotting/_tools/__init__.py
index a421f6b94a..8f189121e2 100644
--- a/src/scanpy/plotting/_tools/__init__.py
+++ b/src/scanpy/plotting/_tools/__init__.py
@@ -158,14 +158,14 @@ def pca_loadings(
         components = np.array(components) - 1
 
     if np.any(components < 0):
-        raise ValueError("Component indices must be greater than zero.")
+        msg = "Component indices must be greater than zero."
+        raise ValueError(msg)
 
     if n_points is None:
         n_points = min(30, adata.n_vars)
     elif adata.n_vars < n_points:
-        raise ValueError(
-            f"Tried to plot {n_points} variables, but passed anndata only has {adata.n_vars}."
-        )
+        msg = f"Tried to plot {n_points} variables, but passed anndata only has {adata.n_vars}."
+        raise ValueError(msg)
 
     ranking(
         adata,
@@ -398,10 +398,11 @@ def rank_genes_groups(
     """
     n_panels_per_row = kwds.get("n_panels_per_row", ncols)
     if n_genes < 1:
-        raise NotImplementedError(
+        msg = (
             "Specifying a negative number for n_genes has not been implemented for "
-            f"this plot. Received n_genes={n_genes}."
+            f"this plot. Received {n_genes=!r}."
         )
+        raise NotImplementedError(msg)
 
     reference = str(adata.uns[key]["params"]["reference"])
     group_names = adata.uns[key]["names"].dtype.names if groups is None else groups
@@ -517,10 +518,11 @@ def _rank_genes_groups_plot(
     Common function to call the different rank_genes_groups_* plots
     """
     if var_names is not None and n_genes is not None:
-        raise ValueError(
+        msg = (
             "The arguments n_genes and var_names are mutually exclusive. Please "
             "select only one."
         )
+        raise ValueError(msg)
 
     if var_names is None and n_genes is None:
         # set n_genes = 10 as default when none of the options is given
@@ -694,7 +696,6 @@ def rank_genes_groups_heatmap(
     {show_save_ax}
     **kwds
         Are passed to :func:`~scanpy.pl.heatmap`.
-    {show_save_ax}
 
     Examples
     --------
@@ -778,7 +779,6 @@ def rank_genes_groups_tracksplot(
     {show_save_ax}
     **kwds
         Are passed to :func:`~scanpy.pl.tracksplot`.
-    {show_save_ax}
 
     Examples
     --------
@@ -1313,9 +1313,7 @@ def rank_genes_groups_violin(
         _ax.set_ylabel("expression")
         _ax.set_xticklabels(new_gene_names, rotation="vertical")
         writekey = (
-            f"rank_genes_groups_"
-            f"{adata.uns[key]['params']['groupby']}_"
-            f"{group_name}"
+            f"rank_genes_groups_{adata.uns[key]['params']['groupby']}_{group_name}"
         )
         savefig_or_show(writekey, show=show, save=save)
         axs.append(_ax)
@@ -1527,7 +1525,8 @@ def embedding_density(
         basis = "draw_graph_fa"
 
     if key is not None and groupby is not None:
-        raise ValueError("either pass key or groupby but not both")
+        msg = "either pass key or groupby but not both"
+        raise ValueError(msg)
 
     if key is None:
         key = "umap_density"
@@ -1535,16 +1534,17 @@ def embedding_density(
         key += f"_{groupby}"
 
     if f"X_{basis}" not in adata.obsm_keys():
-        raise ValueError(
-            f"Cannot find the embedded representation `adata.obsm[X_{basis!r}]`. "
+        msg = (
+            f"Cannot find the embedded representation `adata.obsm['X_{basis}']`. "
             "Compute the embedding first."
         )
+        raise ValueError(msg)
 
     if key not in adata.obs or f"{key}_params" not in adata.uns:
-        raise ValueError(
-            "Please run `sc.tl.embedding_density()` first "
-            "and specify the correct key."
-        )
+        msg = (
+            "Please run `sc.tl.embedding_density()` first and specify the correct key."
+        )
+        raise ValueError(msg)
 
     if "components" in kwargs:
         logg.warning(
@@ -1563,10 +1563,11 @@ def embedding_density(
         group = [group]
 
     if group is None and groupby is not None:
-        raise ValueError(
+        msg = (
             "Densities were calculated over an `.obs` covariate. "
             "Please specify a group from this covariate to plot."
         )
+        raise ValueError(msg)
 
     if group is not None and groupby is None:
         logg.warning(
@@ -1576,7 +1577,8 @@ def embedding_density(
         group = None
 
     if np.min(adata.obs[key]) < 0 or np.max(adata.obs[key]) > 1:
-        raise ValueError("Densities should be scaled between 0 and 1.")
+        msg = "Densities should be scaled between 0 and 1."
+        raise ValueError(msg)
 
     if wspace is None:
         # try to set a wspace that is not too large or too small given the
@@ -1601,17 +1603,19 @@ def embedding_density(
     # (even if only one group is set)
     if group is not None and not isinstance(group, str) and isinstance(group, Sequence):
         if ax is not None:
-            raise ValueError("Can only specify `ax` if no `group` sequence is given.")
+            msg = "Can only specify `ax` if no `group` sequence is given."
+            raise ValueError(msg)
         fig, gs = _panel_grid(hspace, wspace, ncols, len(group))
 
     axs = []
     for count, group_name in enumerate(group):
         if group_name not in adata.obs[groupby].cat.categories:
-            raise ValueError(
+            msg = (
                 "Please specify a group from the `.obs` category "
                 "over which the density was calculated. "
                 f"Invalid group name: {group_name}"
             )
+            raise ValueError(msg)
         ax = plt.subplot(gs[count])
 
         # Define plotting data
@@ -1743,9 +1747,8 @@ def _get_values_to_plot(
         "log10_pvals_adj",
     ]
     if values_to_plot not in valid_options:
-        raise ValueError(
-            f"given value_to_plot: '{values_to_plot}' is not valid. Valid options are {valid_options}"
-        )
+        msg = f"given value_to_plot: '{values_to_plot}' is not valid. Valid options are {valid_options}"
+        raise ValueError(msg)
 
     values_df = None
     check_done = False
diff --git a/src/scanpy/plotting/_tools/paga.py b/src/scanpy/plotting/_tools/paga.py
index e67e6e2ece..a4b2de3441 100644
--- a/src/scanpy/plotting/_tools/paga.py
+++ b/src/scanpy/plotting/_tools/paga.py
@@ -239,10 +239,11 @@ def _compute_pos(
         nx_g_tree = nx.Graph(adj_tree)
         pos = _utils.hierarchy_pos(nx_g_tree, root)
         if len(pos) < adjacency_solid.shape[0]:
-            raise ValueError(
+            msg = (
                 "This is a forest and not a single tree. "
                 "Try another `layout`, e.g., {'fr'}."
             )
+            raise ValueError(msg)
     else:
         # igraph layouts
         random.seed(random_state.bytes(8))
@@ -547,10 +548,8 @@ def is_flat(x):
 
     if isinstance(root, str):
         if root not in labels:
-            raise ValueError(
-                "If `root` is a string, "
-                f"it needs to be one of {labels} not {root!r}."
-            )
+            msg = f"If `root` is a string, it needs to be one of {labels} not {root!r}."
+            raise ValueError(msg)
         root = list(labels).index(root)
     if isinstance(root, Sequence) and root[0] in labels:
         root = [list(labels).index(r) for r in root]
@@ -731,10 +730,11 @@ def _paga_graph(
     else:
         pos = Path(pos)
         if pos.suffix != ".gdf":
-            raise ValueError(
+            msg = (
                 "Currently only supporting reading positions from .gdf files. "
                 "Consider generating them using, for instance, Gephi."
             )
+            raise ValueError(msg)
         s = ""  # read the node definition from the file
         with pos.open() as f:
             f.readline()
@@ -762,7 +762,8 @@ def _paga_graph(
     elif colors == "degree_solid":
         colors = [d for _, d in nx_g_solid.degree(weight="weight")]
     else:
-        raise ValueError('`degree` either "degree_dashed" or "degree_solid".')
+        msg = '`degree` either "degree_dashed" or "degree_solid".'
+        raise ValueError(msg)
     colors = (np.array(colors) - np.min(colors)) / (np.max(colors) - np.min(colors))
 
     # plot gene expression
@@ -811,10 +812,11 @@ def _paga_graph(
         colors = asso_colors
 
     if len(colors) != len(node_labels):
-        raise ValueError(
+        msg = (
             f"Expected `colors` to be of length `{len(node_labels)}`, "
             f"found `{len(colors)}`."
         )
+        raise ValueError(msg)
 
     # count number of connected components
     n_components, labels = scipy.sparse.csgraph.connected_components(adjacency_solid)
@@ -839,7 +841,8 @@ def _paga_graph(
         )
         nx_g_solid = nx.Graph(adjacency_solid)
         if dashed_edges is not None:
-            raise ValueError("`single_component` only if `dashed_edges` is `None`.")
+            msg = "`single_component` only if `dashed_edges` is `None`."
+            raise ValueError(msg)
 
     # edge widths
     base_edge_width = edge_width_scale * 5 * rcParams["lines.linewidth"]
@@ -958,10 +961,11 @@ def _paga_graph(
     else:
         for ix, (xx, yy) in enumerate(zip(pos_array[:, 0], pos_array[:, 1])):
             if not isinstance(colors[ix], Mapping):
-                raise ValueError(
+                msg = (
                     f"{colors[ix]} is neither a dict of valid "
                     "matplotlib colors nor a valid matplotlib color."
) + raise ValueError(msg) color_single = colors[ix].keys() fracs = [colors[ix][c] for c in color_single] total = sum(fracs) @@ -971,10 +975,11 @@ def _paga_graph( color_single.append("grey") fracs.append(1 - sum(fracs)) elif not np.isclose(total, 1): - raise ValueError( + msg = ( f"Expected fractions for node `{ix}` to be " f"close to 1, found `{total}`." ) + raise ValueError(msg) cumsum = np.cumsum(fracs) cumsum = cumsum / cumsum[-1] @@ -1125,18 +1130,20 @@ def paga_path( if groups_key is None: if "groups" not in adata.uns["paga"]: - raise KeyError( + msg = ( "Pass the key of the grouping with which you ran PAGA, " "using the parameter `groups_key`." ) + raise KeyError(msg) groups_key = adata.uns["paga"]["groups"] groups_names = adata.obs[groups_key].cat.categories if "dpt_pseudotime" not in adata.obs.columns: - raise ValueError( + msg = ( "`pl.paga_path` requires computation of a pseudotime `tl.dpt` " "for ordering at single-cell resolution" ) + raise ValueError(msg) if palette_groups is None: _utils.add_colors_for_categorical_sample_annotation(adata, groups_key) @@ -1157,10 +1164,11 @@ def moving_average(a): groups_names_set = set(groups_names) for node in nodes: if node not in groups_names_set: - raise ValueError( + msg = ( f"Each node/group needs to be in {groups_names.tolist()} " - f"(`groups_key`={groups_key!r}) not {node!r}." + f"({groups_key=!r}) not {node!r}." ) + raise ValueError(msg) nodes_ints.append(groups_names.get_loc(node)) nodes_strs = nodes else: @@ -1178,12 +1186,13 @@ def moving_average(a): adata.obs[groups_key].values == nodes_strs[igroup] ] if len(idcs) == 0: - raise ValueError( + msg = ( "Did not find data points that match " f"`adata.obs[{groups_key!r}].values == {str(group)!r}`. " f"Check whether `adata.obs[{groups_key!r}]` " "actually contains what you expect." ) + raise ValueError(msg) idcs_group = np.argsort( adata.obs["dpt_pseudotime"].values[ adata.obs[groups_key].values == nodes_strs[igroup] diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py index b54897678f..cb3c9d7c66 100644 --- a/src/scanpy/plotting/_tools/scatterplots.py +++ b/src/scanpy/plotting/_tools/scatterplots.py @@ -149,7 +149,8 @@ def embedding( # Checking the mask format and if used together with groups if groups is not None and mask_obs is not None: - raise ValueError("Groups and mask arguments are incompatible.") + msg = "Groups and mask arguments are incompatible." + raise ValueError(msg) mask_obs = _check_mask(adata, mask_obs, "obs") # Figure out if we're using raw @@ -157,15 +158,17 @@ def embedding( # check if adata.raw is set use_raw = layer is None and adata.raw is not None if use_raw and layer is not None: - raise ValueError( - "Cannot use both a layer and the raw representation. Was passed:" - f"use_raw={use_raw}, layer={layer}." + msg = ( + "Cannot use both a layer and the raw representation. " + f"Was passed: {use_raw=!r}, {layer=!r}." ) + raise ValueError(msg) if use_raw and adata.raw is None: - raise ValueError( + msg = ( "`use_raw` is set to True but AnnData object does not have raw. " "Please check." ) + raise ValueError(msg) if isinstance(groups, str): groups = [groups] @@ -173,7 +176,8 @@ def embedding( # Color map if color_map is not None: if cmap is not None: - raise ValueError("Cannot specify both `color_map` and `cmap`.") + msg = "Cannot specify both `color_map` and `cmap`." 
+ raise ValueError(msg) else: cmap = color_map cmap = copy(colormaps.get_cmap(cmap)) @@ -245,10 +249,11 @@ def embedding( not isinstance(color, str) and isinstance(color, Sequence) and len(color) > 1 ) or len(dimensions) > 1: if ax is not None: - raise ValueError( + msg = ( "Cannot specify `ax` when plotting multiple panels " "(each for a given value of 'color')." ) + raise ValueError(msg) # each plot needs to be its own panel fig, grid = _panel_grid(hspace, wspace, ncols, len(color)) @@ -810,9 +815,8 @@ def draw_graph( layout = str(adata.uns["draw_graph"]["params"]["layout"]) basis = f"draw_graph_{layout}" if f"X_{basis}" not in adata.obsm_keys(): - raise ValueError( - f"Did not find {basis} in adata.obs. Did you compute layout {layout}?" - ) + msg = f"Did not find {basis} in adata.obs. Did you compute layout {layout}?" + raise ValueError(msg) return embedding(adata, basis, **kwargs) @@ -883,10 +887,11 @@ def pca( adata, "pca", show=show, return_fig=return_fig, save=save, **kwargs ) if "pca" not in adata.obsm and "X_pca" not in adata.obsm: - raise KeyError( + msg = ( f"Could not find entry in `obsm` for 'pca'.\n" f"Available keys are: {list(adata.obsm.keys())}." ) + raise KeyError(msg) label_dict = { f"PC{i + 1}": f"PC{i + 1} ({round(v * 100, 2)}%)" @@ -1060,7 +1065,8 @@ def _components_to_dimensions( if components is None and dimensions is None: dimensions = [tuple(i for i in range(ndims))] elif components is not None and dimensions is not None: - raise ValueError("Cannot provide both dimensions and components") + msg = "Cannot provide both dimensions and components" + raise ValueError(msg) # TODO: Consider deprecating this # If components is not None, parse them and set dimensions @@ -1099,9 +1105,8 @@ def _add_categorical_legend( """Add a legend to the passed Axes.""" if na_in_legend and pd.isnull(color_source_vector).any(): if "NA" in color_source_vector: - raise NotImplementedError( - "No fallback for null labels has been defined if NA already in categories." - ) + msg = "No fallback for null labels has been defined if NA already in categories." + raise NotImplementedError(msg) color_source_vector = color_source_vector.add_categories("NA").fillna("NA") palette = palette.copy() palette["NA"] = na_color @@ -1162,7 +1167,8 @@ def _get_basis(adata: AnnData, basis: str) -> np.ndarray: elif f"X_{basis}" in adata.obsm: return adata.obsm[f"X_{basis}"] else: - raise KeyError(f"Could not find '{basis}' or 'X_{basis}' in .obsm") + msg = f"Could not find {basis!r} or 'X_{basis}' in .obsm" + raise KeyError(msg) def _get_color_source_vector( @@ -1294,10 +1300,11 @@ def _check_spot_size(spatial_data: Mapping | None, spot_size: float | None) -> f This is a required argument for spatial plots. """ if spatial_data is None and spot_size is None: - raise ValueError( + msg = ( "When .uns['spatial'][library_id] does not exist, spot_size must be " "provided directly." ) + raise ValueError(msg) elif spot_size is None: return spatial_data["scalefactors"]["spot_diameter_fullres"] else: @@ -1329,10 +1336,11 @@ def _check_spatial_data( spatial_mapping = uns.get("spatial", {}) if library_id is _empty: if len(spatial_mapping) > 1: - raise ValueError( + msg = ( "Found multiple possible libraries in `.uns['spatial']. Please specify." 
f" Options are:\n\t{list(spatial_mapping.keys())}" ) + raise ValueError(msg) elif len(spatial_mapping) == 1: library_id = list(spatial_mapping.keys())[0] else: @@ -1370,7 +1378,8 @@ def _check_crop_coord( if crop_coord is None: return None if len(crop_coord) != 4: - raise ValueError("Invalid crop_coord of length {len(crop_coord)}(!=4)") + msg = "Invalid crop_coord of length {len(crop_coord)}(!=4)" + raise ValueError(msg) crop_coord = tuple(c * scale_factor for c in crop_coord) return crop_coord @@ -1389,7 +1398,8 @@ def _broadcast_args(*args): lens = [len(arg) for arg in args] longest = max(lens) if not (set(lens) == {1, longest} or set(lens) == {longest}): - raise ValueError(f"Could not broadcast together arguments with shapes: {lens}.") + msg = f"Could not broadcast together arguments with shapes: {lens}." + raise ValueError(msg) return list( [[arg[0] for _ in range(longest)] if len(arg) == 1 else arg for arg in args] ) diff --git a/src/scanpy/plotting/_utils.py b/src/scanpy/plotting/_utils.py index 09a01a9bc5..b6cd920039 100644 --- a/src/scanpy/plotting/_utils.py +++ b/src/scanpy/plotting/_utils.py @@ -398,7 +398,7 @@ def _validate_palette(adata: AnnData, key: str) -> None: else: logg.warning( f"The following color value found in adata.uns['{key}_colors'] " - f"is not valid: '{color}'. Default colors will be used instead." + f"is not valid: {color!r}. Default colors will be used instead." ) _set_default_colors_for_categorical_obs(adata, key) _palette = None @@ -466,21 +466,24 @@ def _set_colors_for_categorical_obs( if color in additional_colors: color = additional_colors[color] else: - raise ValueError( + msg = ( "The following color value of the given palette " f"is not valid: {color}" ) + raise ValueError(msg) _color_list.append(color) palette = cycler(color=_color_list) if not isinstance(palette, Cycler): - raise ValueError( + msg = ( "Please check that the value of 'palette' is a valid " "matplotlib colormap string (eg. Set2), a list of color names " "or a cycler with a 'color' key." ) + raise ValueError(msg) if "color" not in palette.keys: - raise ValueError("Please set the palette key 'color'.") + msg = "Please set the palette key 'color'." + raise ValueError(msg) cc = palette() colors_list = [to_hex(next(cc)["color"]) for x in range(len(categories))] @@ -556,7 +559,8 @@ def plot_edges(axs, adata, basis, edges_width, edges_color, *, neighbors_key=Non if neighbors_key is None: neighbors_key = "neighbors" if neighbors_key not in adata.uns: - raise ValueError("`edges=True` requires `pp.neighbors` to be run before.") + msg = "`edges=True` requires `pp.neighbors` to be run before." + raise ValueError(msg) neighbors = NeighborsView(adata, neighbors_key) g = nx.Graph(neighbors["connectivities"]) basis_key = _get_basis(adata, basis) @@ -582,11 +586,12 @@ def plot_arrows(axs, adata, basis, arrows_kwds=None): (p for p in ["velocity", "Delta"] if f"{p}_{basis}" in adata.obsm), None ) if v_prefix is None: - raise ValueError( + msg = ( "`arrows=True` requires " f"`'velocity_{basis}'` from scvelo or " f"`'Delta_{basis}'` from velocyto." ) + raise ValueError(msg) if v_prefix == "velocity": logg.warning( "The module `scvelo` has improved plotting facilities. " @@ -628,7 +633,8 @@ def scatter_group( color = rgb2hex(adata.uns[key + "_colors"][cat_code]) if not is_color_like(color): - raise ValueError(f'"{color}" is not a valid matplotlib color.') + msg = f"{color!r} is not a valid matplotlib color." 
+ raise ValueError(msg) data = [Y[mask_obs, 0], Y[mask_obs, 1]] if projection == "3d": data.append(Y[mask_obs, 2]) @@ -658,7 +664,8 @@ def setup_axes( """Grid of axes for plotting, legends and colorbars.""" check_projection(projection) if left_margin is not None: - raise NotImplementedError("We currently don’t support `left_margin`.") + msg = "We currently don’t support `left_margin`." + raise NotImplementedError(msg) if np.any(colorbars) and right_margin is None: right_margin = 1 - rcParams["figure.subplot.right"] + 0.21 # 0.25 elif right_margin is None: @@ -801,7 +808,8 @@ def scatter_base( elif projection == "3d": data = Y_sort[:, 0], Y_sort[:, 1], Y_sort[:, 2] else: - raise ValueError(f"Unknown projection {projection!r} not in '2d', '3d'") + msg = f"Unknown projection {projection!r} not in '2d', '3d'" + raise ValueError(msg) if not isinstance(color, str) or color != "white": sct = ax.scatter( *data, @@ -1148,15 +1156,15 @@ def data_to_axis_points(ax: Axes, points_data: np.ndarray): def check_projection(projection): """Validation for projection argument.""" if projection not in {"2d", "3d"}: - raise ValueError(f"Projection must be '2d' or '3d', was '{projection}'.") + msg = f"Projection must be '2d' or '3d', was '{projection}'." + raise ValueError(msg) if projection == "3d": from packaging.version import parse mpl_version = parse(mpl.__version__) if mpl_version < parse("3.3.3"): - raise ImportError( - f"3d plotting requires matplotlib > 3.3.3. Found {mpl.__version__}" - ) + msg = f"3d plotting requires matplotlib > 3.3.3. Found {mpl.__version__}" + raise ImportError(msg) def circles( @@ -1300,7 +1308,8 @@ def check_colornorm(vmin=None, vmax=None, vcenter=None, norm=None): if norm is not None: if (vmin is not None) or (vmax is not None) or (vcenter is not None): - raise ValueError("Passing both norm and vmin/vmax/vcenter is not allowed.") + msg = "Passing both norm and vmin/vmax/vcenter is not allowed." 
+ raise ValueError(msg) else: if vcenter is not None: norm = DivNorm(vmin=vmin, vmax=vmax, vcenter=vcenter) diff --git a/src/scanpy/preprocessing/_combat.py b/src/scanpy/preprocessing/_combat.py index caeb9a0b45..93052f356c 100644 --- a/src/scanpy/preprocessing/_combat.py +++ b/src/scanpy/preprocessing/_combat.py @@ -179,21 +179,23 @@ def combat( # check the input if key not in adata.obs_keys(): - raise ValueError(f"Could not find the key {key!r} in adata.obs") + msg = f"Could not find the key {key!r} in adata.obs" + raise ValueError(msg) if covariates is not None: cov_exist = np.isin(covariates, adata.obs_keys()) if np.any(~cov_exist): missing_cov = np.array(covariates)[~cov_exist].tolist() - raise ValueError( - f"Could not find the covariate(s) {missing_cov!r} in adata.obs" - ) + msg = f"Could not find the covariate(s) {missing_cov!r} in adata.obs" + raise ValueError(msg) if key in covariates: - raise ValueError("Batch key and covariates cannot overlap") + msg = "Batch key and covariates cannot overlap" + raise ValueError(msg) if len(covariates) != len(set(covariates)): - raise ValueError("Covariates must be unique") + msg = "Covariates must be unique" + raise ValueError(msg) # only works on dense matrices so far X = adata.X.toarray().T if issparse(adata.X) else adata.X.T diff --git a/src/scanpy/preprocessing/_deprecated/__init__.py b/src/scanpy/preprocessing/_deprecated/__init__.py index c23361631a..b821417c0b 100644 --- a/src/scanpy/preprocessing/_deprecated/__init__.py +++ b/src/scanpy/preprocessing/_deprecated/__init__.py @@ -36,7 +36,8 @@ def normalize_per_cell_weinreb16_deprecated( Normalized version of the original expression matrix. """ if max_fraction < 0 or max_fraction > 1: - raise ValueError("Choose max_fraction between 0 and 1.") + msg = "Choose max_fraction between 0 and 1." + raise ValueError(msg) counts_per_cell = x.sum(1).A1 if issparse(x) else x.sum(1) gene_subset = np.all(x <= counts_per_cell[:, None] * max_fraction, axis=0) diff --git a/src/scanpy/preprocessing/_deprecated/highly_variable_genes.py b/src/scanpy/preprocessing/_deprecated/highly_variable_genes.py index 27e8f1f846..bba4fb9bbf 100644 --- a/src/scanpy/preprocessing/_deprecated/highly_variable_genes.py +++ b/src/scanpy/preprocessing/_deprecated/highly_variable_genes.py @@ -214,7 +214,8 @@ def filter_genes_dispersion( / disp_mad_bin[df["mean_bin"].values].values ) else: - raise ValueError('`flavor` needs to be "seurat" or "cell_ranger"') + msg = '`flavor` needs to be "seurat" or "cell_ranger"' + raise ValueError(msg) dispersion_norm = df["dispersion_norm"].values.astype("float32") if n_top_genes is not None: dispersion_norm = dispersion_norm[~np.isnan(dispersion_norm)] @@ -268,7 +269,8 @@ def filter_genes_fano_deprecated(X, Ecutoff, Vcutoff): def _filter_genes(X, e_cutoff, v_cutoff, meth): """See `filter_genes_dispersion` :cite:p:`Weinreb2017`.""" if issparse(X): - raise ValueError("Not defined for sparse input. See `filter_genes_dispersion`.") + msg = "Not defined for sparse input. See `filter_genes_dispersion`." 
+ raise ValueError(msg) mean_filter = np.mean(X, axis=0) > e_cutoff var_filter = meth(X, axis=0) / (np.mean(X, axis=0) + 0.0001) > v_cutoff gene_subset = np.nonzero(np.all([mean_filter, var_filter], axis=0))[0] diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index e34340b256..356fa8f03f 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -65,15 +65,14 @@ def _highly_variable_genes_seurat_v3( try: from skmisc.loess import loess except ImportError: - raise ImportError( - "Please install skmisc package via `pip install --user scikit-misc" - ) + msg = "Please install skmisc package via `pip install --user scikit-misc" + raise ImportError(msg) df = pd.DataFrame(index=adata.var_names) data = _get_obs_rep(adata, layer=layer) if check_values and not check_nonnegative_integers(data): warnings.warn( - f"`flavor='{flavor}'` expects raw count data, but non-integers were found.", + f"`{flavor=!r}` expects raw count data, but non-integers were found.", UserWarning, ) @@ -159,7 +158,8 @@ def _highly_variable_genes_seurat_v3( sort_cols = ["highly_variable_nbatches", "highly_variable_rank"] sort_ascending = [False, True] else: - raise ValueError(f"Did not recognize flavor {flavor}") + msg = f"Did not recognize flavor {flavor}" + raise ValueError(msg) sorted_index = ( df[sort_cols] .sort_values(sort_cols, ascending=sort_ascending, na_position="last") @@ -332,7 +332,8 @@ def _get_mean_bins( elif flavor == "cell_ranger": bins = np.r_[-np.inf, np.percentile(means, np.arange(10, 105, 5)), np.inf] else: - raise ValueError('`flavor` needs to be "seurat" or "cell_ranger"') + msg = '`flavor` needs to be "seurat" or "cell_ranger"' + raise ValueError(msg) return pd.cut(means, bins=bins) @@ -347,7 +348,8 @@ def _get_disp_stats( elif flavor == "cell_ranger": disp_bin_stats = disp_grouped.agg(avg="median", dev=_mad) else: - raise ValueError('`flavor` needs to be "seurat" or "cell_ranger"') + msg = '`flavor` needs to be "seurat" or "cell_ranger"' + raise ValueError(msg) return disp_bin_stats.loc[df["mean_bin"]].set_index(df.index) @@ -647,10 +649,11 @@ def highly_variable_genes( start = logg.info("extracting highly variable genes") if not isinstance(adata, AnnData): - raise ValueError( + msg = ( "`pp.highly_variable_genes` expects an `AnnData` argument, " "pass `inplace=False` if you want to return a `pd.DataFrame`." ) + raise ValueError(msg) if flavor in {"seurat_v3", "seurat_v3_paper"}: if n_top_genes is None: diff --git a/src/scanpy/preprocessing/_normalization.py b/src/scanpy/preprocessing/_normalization.py index c888ded9c6..e1ee3d4822 100644 --- a/src/scanpy/preprocessing/_normalization.py +++ b/src/scanpy/preprocessing/_normalization.py @@ -175,11 +175,13 @@ def normalize_total( """ if copy: if not inplace: - raise ValueError("`copy=True` cannot be used with `inplace=False`.") + msg = "`copy=True` cannot be used with `inplace=False`." + raise ValueError(msg) adata = adata.copy() if max_fraction < 0 or max_fraction > 1: - raise ValueError("Choose max_fraction between 0 and 1.") + msg = "Choose max_fraction between 0 and 1." 
+ raise ValueError(msg) # Deprecated features if layers is not None: @@ -200,9 +202,8 @@ def normalize_total( if layers == "all": layers = adata.layers.keys() elif isinstance(layers, str): - raise ValueError( - f"`layers` needs to be a list of strings or 'all', not {layers!r}" - ) + msg = f"`layers` needs to be a list of strings or 'all', not {layers!r}" + raise ValueError(msg) view_to_actual(adata) @@ -254,7 +255,8 @@ def normalize_total( elif layer_norm is None: after = None else: - raise ValueError('layer_norm should be "after", "X" or None') + msg = 'layer_norm should be "after", "X" or None' + raise ValueError(msg) for layer_to_norm in layers if layers is not None else (): res = normalize_total( diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index 3fd288ad93..db7886a29f 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -208,7 +208,8 @@ def pca( logg_start = logg.info("computing PCA") if layer is not None and chunked: # Current chunking implementation relies on pca being called on X - raise NotImplementedError("Cannot use `layer` and `chunked` at the same time.") + msg = "Cannot use `layer` and `chunked` at the same time." + raise NotImplementedError(msg) # chunked calculation is not randomized, anyways if svd_solver in {"auto", "randomized"} and not chunked: @@ -220,9 +221,8 @@ def pca( data_is_AnnData = isinstance(data, AnnData) if data_is_AnnData: if layer is None and not chunked and is_backed_type(data.X): - raise NotImplementedError( - f"PCA is not implemented for matrices of type {type(data.X)} with chunked as False" - ) + msg = f"PCA is not implemented for matrices of type {type(data.X)} with chunked as False" + raise NotImplementedError(msg) adata = data.copy() if copy else data else: if pkg_version("anndata") < Version("0.8.0rc1"): @@ -239,13 +239,12 @@ def pca( min_dim = min(adata_comp.n_vars, adata_comp.n_obs) n_comps = min_dim - 1 if min_dim <= settings.N_PCS else settings.N_PCS - logg.info(f" with n_comps={n_comps}") + logg.info(f" with {n_comps=}") X = _get_obs_rep(adata_comp, layer=layer) if is_backed_type(X) and layer is not None: - raise NotImplementedError( - f"PCA is not implemented for matrices of type {type(X)} from layers" - ) + msg = f"PCA is not implemented for matrices of type {type(X)} from layers" + raise NotImplementedError(msg) # See: https://github.com/scverse/scanpy/pull/2816#issuecomment-1932650529 if ( Version(ad.__version__) < Version("0.9") diff --git a/src/scanpy/preprocessing/_qc.py b/src/scanpy/preprocessing/_qc.py index 87ad51d420..5af8def042 100644 --- a/src/scanpy/preprocessing/_qc.py +++ b/src/scanpy/preprocessing/_qc.py @@ -32,10 +32,11 @@ def _choose_mtx_rep(adata, *, use_raw: bool = False, layer: str | None = None): is_layer = layer is not None if use_raw and is_layer: - raise ValueError( + msg = ( "Cannot use expression from both layer and raw. You provided:" - f"'use_raw={use_raw}' and 'layer={layer}'" + f"{use_raw=!r} and {layer=!r}" ) + raise ValueError(msg) if is_layer: return adata.layers[layer] elif use_raw: @@ -384,7 +385,8 @@ def top_proportions_sparse_csr(data, indptr, n): def check_ns(func): def check_ns_inner(mtx: np.ndarray | spmatrix | DaskArray, ns: Collection[int]): if not (max(ns) <= mtx.shape[1] and min(ns) > 0): - raise IndexError("Positions outside range of features.") + msg = "Positions outside range of features." 
+ raise IndexError(msg) return func(mtx, ns) return check_ns_inner diff --git a/src/scanpy/preprocessing/_recipes.py b/src/scanpy/preprocessing/_recipes.py index 4b97405df9..4748d75e5c 100644 --- a/src/scanpy/preprocessing/_recipes.py +++ b/src/scanpy/preprocessing/_recipes.py @@ -59,7 +59,8 @@ def recipe_weinreb17( from ._deprecated import normalize_per_cell_weinreb16_deprecated, zscore_deprecated if issparse(adata.X): - raise ValueError("`recipe_weinreb16 does not support sparse matrices.") + msg = "`recipe_weinreb16 does not support sparse matrices." + raise ValueError(msg) if copy: adata = adata.copy() if log: diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py index bac08f246b..ee15f977b9 100644 --- a/src/scanpy/preprocessing/_scale.py +++ b/src/scanpy/preprocessing/_scale.py @@ -133,13 +133,11 @@ def scale( """ _check_array_function_arguments(layer=layer, obsm=obsm) if layer is not None: - raise ValueError( - f"`layer` argument inappropriate for value of type {type(data)}" - ) + msg = f"`layer` argument inappropriate for value of type {type(data)}" + raise ValueError(msg) if obsm is not None: - raise ValueError( - f"`obsm` argument inappropriate for value of type {type(data)}" - ) + msg = f"`obsm` argument inappropriate for value of type {type(data)}" + raise ValueError(msg) return scale_array( data, zero_center=zero_center, max_value=max_value, copy=copy, mask_obs=mask_obs ) @@ -184,7 +182,7 @@ def scale_array( if not zero_center and max_value is not None: logg.info( # Be careful of what? This should be more specific - "... be careful when using `max_value` " "without `zero_center`." + "... be careful when using `max_value` without `zero_center`." ) if np.issubdtype(X.dtype, np.integer): diff --git a/src/scanpy/preprocessing/_scrublet/pipeline.py b/src/scanpy/preprocessing/_scrublet/pipeline.py index 586587e2cf..6e52a6650c 100644 --- a/src/scanpy/preprocessing/_scrublet/pipeline.py +++ b/src/scanpy/preprocessing/_scrublet/pipeline.py @@ -53,7 +53,8 @@ def truncated_svd( algorithm: Literal["arpack", "randomized"] = "arpack", ) -> None: if self._counts_sim_norm is None: - raise RuntimeError("_counts_sim_norm is not set") + msg = "_counts_sim_norm is not set" + raise RuntimeError(msg) from sklearn.decomposition import TruncatedSVD svd = TruncatedSVD( @@ -72,7 +73,8 @@ def pca( svd_solver: Literal["auto", "full", "arpack", "randomized"] = "arpack", ) -> None: if self._counts_sim_norm is None: - raise RuntimeError("_counts_sim_norm is not set") + msg = "_counts_sim_norm is not set" + raise RuntimeError(msg) from sklearn.decomposition import PCA X_obs = self._counts_obs_norm.toarray() diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index ac68edd376..fda79b4da2 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -146,10 +146,11 @@ def filter_cells( option is not None for option in [min_genes, min_counts, max_genes, max_counts] ) if n_given_options != 1: - raise ValueError( + msg = ( "Only provide one of the optional parameters `min_counts`, " "`min_genes`, `max_counts`, `max_genes` per call." 
) + raise ValueError(msg) if isinstance(data, AnnData): raise_not_implemented_error_if_backed_type(data.X, "filter_cells") adata = data.copy() if copy else data @@ -261,10 +262,11 @@ def filter_genes( option is not None for option in [min_cells, min_counts, max_cells, max_counts] ) if n_given_options != 1: - raise ValueError( + msg = ( "Only provide one of the optional parameters `min_counts`, " "`min_cells`, `max_counts`, `max_cells` per call." ) + raise ValueError(msg) if isinstance(data, AnnData): raise_not_implemented_error_if_backed_type(data.X, "filter_genes") @@ -407,13 +409,13 @@ def log1p_anndata( if chunked: if (layer is not None) or (obsm is not None): - raise NotImplementedError( + msg = ( "Currently cannot perform chunked operations on arrays not stored in X." ) + raise NotImplementedError(msg) if adata.isbacked and adata.file._filemode != "r+": - raise NotImplementedError( - "log1p is not implemented for backed AnnData with backed mode not r+" - ) + msg = "log1p is not implemented for backed AnnData with backed mode not r+" + raise NotImplementedError(msg) for chunk, start, end in adata.chunked_X(chunk_size): adata.X[start:end] = log1p(chunk, base=base, copy=False) else: @@ -421,8 +423,10 @@ def log1p_anndata( if is_backed_type(X): msg = f"log1p is not implemented for matrices of type {type(X)}" if layer is not None: - raise NotImplementedError(f"{msg} from layers") - raise NotImplementedError(f"{msg} without `chunked=True`") + msg = f"{msg} from layers" + raise NotImplementedError(msg) + msg = f"{msg} without `chunked=True`" + raise NotImplementedError(msg) X = log1p(X, copy=False, base=base) _set_obs_rep(adata, X, layer=layer, obsm=obsm) @@ -595,7 +599,8 @@ def normalize_per_cell( elif use_rep is None: after = None else: - raise ValueError('use_rep should be "after", "X" or None') + msg = 'use_rep should be "after", "X" or None' + raise ValueError(msg) for layer in layers: _subset, counts = filter_cells(adata.layers[layer], min_counts=min_counts) temp = normalize_per_cell(adata.layers[layer], after, counts, copy=True) @@ -611,7 +616,8 @@ def normalize_per_cell( X = data.copy() if copy else data if counts_per_cell is None: if not copy: - raise ValueError("Can only be run with copy=True") + msg = "Can only be run with copy=True" + raise ValueError(msg) cell_subset, counts_per_cell = filter_cells(X, min_counts=min_counts) X = X[cell_subset] counts_per_cell = counts_per_cell[cell_subset] @@ -719,11 +725,12 @@ def regress_out( adata.obs[keys[0]].dtype, CategoricalDtype ): if len(keys) > 1: - raise ValueError( + msg = ( "If providing categorical variable, " "only a single one is allowed. For this one " "we regress on the mean for each category." ) + raise ValueError(msg) logg.debug("... regressing on per-gene means within categories") regressors = np.zeros(X.shape, dtype="float32") X = _to_dense(X, order="F") if issparse(X) else X @@ -1017,9 +1024,8 @@ def downsample_counts( total_counts_call = total_counts is not None counts_per_cell_call = counts_per_cell is not None if total_counts_call is counts_per_cell_call: - raise ValueError( - "Must specify exactly one of `total_counts` or `counts_per_cell`." - ) + msg = "Must specify exactly one of `total_counts` or `counts_per_cell`." 
+ raise ValueError(msg) if copy: adata = adata.copy() if total_counts_call: @@ -1039,11 +1045,12 @@ def _downsample_per_cell(X, counts_per_cell, random_state, replace): # np.random.choice needs int arguments in numba code: counts_per_cell = counts_per_cell.astype(np.int_, copy=False) if not isinstance(counts_per_cell, np.ndarray) or len(counts_per_cell) != n_obs: - raise ValueError( + msg = ( "If provided, 'counts_per_cell' must be either an integer, or " "coercible to an `np.ndarray` of length as number of observations" " by `np.asarray(counts_per_cell)`." ) + raise ValueError(msg) if issparse(X): original_type = type(X) if not isspmatrix_csr(X): diff --git a/src/scanpy/preprocessing/_utils.py b/src/scanpy/preprocessing/_utils.py index b200e89ce8..3ca74734c0 100644 --- a/src/scanpy/preprocessing/_utils.py +++ b/src/scanpy/preprocessing/_utils.py @@ -64,7 +64,8 @@ def sparse_mean_variance_axis(mtx: sparse.spmatrix, axis: int): ax_minor = 0 shape = mtx.shape[::-1] else: - raise ValueError("This function only works on sparse csr and csc matrices") + msg = "This function only works on sparse csr and csc matrices" + raise ValueError(msg) if axis == ax_minor: return sparse_mean_var_major_axis( mtx.data, diff --git a/src/scanpy/queries/_queries.py b/src/scanpy/queries/_queries.py index 8da90151ce..e992f937e3 100644 --- a/src/scanpy/queries/_queries.py +++ b/src/scanpy/queries/_queries.py @@ -63,13 +63,13 @@ def simple_query( elif isinstance(attrs, Iterable): attrs = list(attrs) else: - raise TypeError(f"attrs must be of type list or str, was {type(attrs)}.") + msg = f"attrs must be of type list or str, was {type(attrs)}." + raise TypeError(msg) try: from pybiomart import Server except ImportError: - raise ImportError( - "This method requires the `pybiomart` module to be installed." - ) + msg = "This method requires the `pybiomart` module to be installed." + raise ImportError(msg) server = Server(host, use_cache=use_cache) dataset = server.marts["ENSEMBL_MART_ENSEMBL"].datasets[f"{org}_gene_ensembl"] res = dataset.query(attributes=attrs, filters=filters, use_attr_names=True) @@ -273,17 +273,17 @@ def enrich( try: from gprofiler import GProfiler except ImportError: - raise ImportError( - "This method requires the `gprofiler-official` module to be installed." - ) + msg = "This method requires the `gprofiler-official` module to be installed." + raise ImportError(msg) gprofiler = GProfiler(user_agent="scanpy", return_dataframe=True) gprofiler_kwargs = dict(gprofiler_kwargs) for k in ["organism"]: if gprofiler_kwargs.get(k) is not None: - raise ValueError( + msg = ( f"Argument `{k}` should be passed directly through `enrich`, " "not through `gprofiler_kwargs`" ) + raise ValueError(msg) return gprofiler.profile(container, organism=org, **gprofiler_kwargs) diff --git a/src/scanpy/readwrite.py b/src/scanpy/readwrite.py index 3333fbc0a1..c568519cd7 100644 --- a/src/scanpy/readwrite.py +++ b/src/scanpy/readwrite.py @@ -41,6 +41,7 @@ from ._utils import _empty if TYPE_CHECKING: + from datetime import datetime from typing import BinaryIO, Literal from ._utils import Empty @@ -155,13 +156,14 @@ def read( filekey = str(filename) filename = settings.writedir / (filekey + "." + settings.file_format_data) if not filename.exists(): - raise ValueError( + msg = ( f"Reading with filekey {filekey!r} failed, " f"the inferred filename {filename!r} does not exist. " "If you intended to provide a filename, either use a filename " f"ending on one of the available extensions {avail_exts} " "or pass the parameter `ext`." 
) + raise ValueError(msg) return read_h5ad(filename, backed=backed) @@ -219,40 +221,46 @@ def read_10x_h5( adata = _read_v3_10x_h5(filename, start=start) if genome: if genome not in adata.var["genome"].values: - raise ValueError( - f"Could not find data corresponding to genome '{genome}' in '{filename}'. " - f'Available genomes are: {list(adata.var["genome"].unique())}.' + msg = ( + f"Could not find data corresponding to genome {genome!r} in {filename}. " + f"Available genomes are: {list(adata.var['genome'].unique())}." ) + raise ValueError(msg) adata = adata[:, adata.var["genome"] == genome] if gex_only: adata = adata[:, adata.var["feature_types"] == "Gene Expression"] if adata.is_view: adata = adata.copy() else: - adata = _read_legacy_10x_h5(filename, genome=genome, start=start) + adata = _read_legacy_10x_h5(Path(filename), genome=genome, start=start) return adata -def _read_legacy_10x_h5(filename, *, genome=None, start=None): +def _read_legacy_10x_h5( + path: Path, *, genome: str | None = None, start: datetime | None = None +): """ Read hdf5 file from Cell Ranger v2 or earlier versions. """ - with h5py.File(str(filename), "r") as f: + with h5py.File(str(path), "r") as f: try: children = list(f.keys()) if not genome: if len(children) > 1: - raise ValueError( - f"'{filename}' contains more than one genome. For legacy 10x h5 " - "files you must specify the genome if more than one is present. " + msg = ( + f"{path} contains more than one genome. " + "For legacy 10x h5 files you must specify the genome " + "if more than one is present. " f"Available genomes are: {children}" ) + raise ValueError(msg) genome = children[0] elif genome not in children: - raise ValueError( - f"Could not find genome '{genome}' in '{filename}'. " + msg = ( + f"Could not find genome {genome!r} in {path}. " f"Available genomes are: {children}" ) + raise ValueError(msg) dsets = {} _collect_datasets(dsets, f[genome]) @@ -283,7 +291,8 @@ def _read_legacy_10x_h5(filename, *, genome=None, start=None): logg.info("", time=start) return adata except KeyError: - raise Exception("File is missing one or more required datasets.") + msg = "File is missing one or more required datasets." + raise Exception(msg) def _collect_datasets(dsets: dict, group: h5py.Group): @@ -354,7 +363,8 @@ def _read_v3_10x_h5(filename, *, start=None): ] ) else: - raise ValueError("10x h5 has no features group") + msg = "10x h5 has no features group" + raise ValueError(msg) adata = AnnData( matrix, obs=obs_dict, @@ -363,7 +373,8 @@ def _read_v3_10x_h5(filename, *, start=None): logg.info("", time=start) return adata except KeyError: - raise Exception("File is missing one or more required datasets.") + msg = "File is missing one or more required datasets." + raise Exception(msg) @deprecated("Use `squidpy.read.visium` instead.") @@ -468,11 +479,11 @@ def read_visium( if not f.exists(): if any(x in str(f) for x in ["hires_image", "lowres_image"]): logg.warning( - f"You seem to be missing an image file.\n" - f"Could not find '{f}'." + f"You seem to be missing an image file.\nCould not find {f}." 
) else: - raise OSError(f"Could not find '{f}'") + msg = f"Could not find {f}" + raise OSError(msg) adata.uns["spatial"][library_id]["images"] = dict() for res in ["hires", "lowres"]: @@ -481,7 +492,8 @@ def read_visium( str(files[f"{res}_image"]) ) except Exception: - raise OSError(f"Could not find '{res}_image'") + msg = f"Could not find '{res}_image'" + raise OSError(msg) # read json scalefactors adata.uns["spatial"][library_id]["scalefactors"] = json.loads( @@ -623,7 +635,8 @@ def _read_10x_mtx( adata.var_names = genes[0].values adata.var["gene_symbols"] = genes[1].values else: - raise ValueError("`var_names` needs to be 'gene_symbols' or 'gene_ids'") + msg = "`var_names` needs to be 'gene_symbols' or 'gene_ids'" + raise ValueError(msg) if not is_legacy: adata.var["feature_types"] = genes[2].values barcodes = pd.read_csv(path / f"{prefix}barcodes.tsv{suffix}", header=None) @@ -667,11 +680,12 @@ def write( if ext is None: ext = ext_ elif ext != ext_: - raise ValueError( + msg = ( "It suffices to provide the file type by " "providing a proper extension to the filename." 'One of "txt", "csv", "h5" or "npz".' ) + raise ValueError(msg) else: key = filename ext = settings.file_format_data if ext is None else ext @@ -767,9 +781,8 @@ def _read( **kwargs, ): if ext is not None and ext not in avail_exts: - raise ValueError( - "Please provide one of the available extensions.\n" f"{avail_exts}" - ) + msg = f"Please provide one of the available extensions.\n{avail_exts}" + raise ValueError(msg) else: ext = is_valid_filename(filename, return_ext=True) is_present = _check_datafile_present_and_download(filename, backup_url=backup_url) @@ -793,7 +806,8 @@ def _read( return read_h5ad(path_cache) if not is_present: - raise FileNotFoundError(f"Did not find file {filename}.") + msg = f"Did not find file {filename}." + raise FileNotFoundError(msg) logg.debug(f"reading {filename}") if not cache and not suppress_cache_warning: logg.hint( @@ -803,7 +817,8 @@ def _read( # do the actual reading if ext == "xlsx" or ext == "xls": if sheet is None: - raise ValueError("Provide `sheet` parameter when reading '.xlsx' files.") + msg = "Provide `sheet` parameter when reading '.xlsx' files." + raise ValueError(msg) else: adata = read_excel(filename, sheet) elif ext in {"mtx", "mtx.gz"}: @@ -817,7 +832,7 @@ def _read( elif ext in {"txt", "tab", "data", "tsv"}: if ext == "data": logg.hint( - "... assuming '.data' means tab or white-space " "separated text file", + "... assuming '.data' means tab or white-space separated text file" ) logg.hint("change this by passing `ext` to sc.read") adata = read_text(filename, delimiter, first_column_names) @@ -826,7 +841,8 @@ def _read( elif ext == "loom": adata = read_loom(filename=filename, **kwargs) else: - raise ValueError(f"Unknown extension {ext}.") + msg = f"Unknown extension {ext}." + raise ValueError(msg) if cache: logg.info( f"... writing an {settings.file_format_data} " @@ -1091,11 +1107,10 @@ def is_valid_filename(filename: Path, *, return_ext: bool = False): return "mtx.gz" if return_ext else True elif not return_ext: return False - raise ValueError( - f"""\ + msg = f"""\ {filename!r} does not end on a valid extension. Please, provide one of the available extensions. 
{avail_exts} Text files with .gz and .bz2 extensions are also supported.\ """ - ) + raise ValueError(msg) diff --git a/src/scanpy/tools/_dendrogram.py b/src/scanpy/tools/_dendrogram.py index f60f0ae2e9..b31e792f31 100644 --- a/src/scanpy/tools/_dendrogram.py +++ b/src/scanpy/tools/_dendrogram.py @@ -124,15 +124,17 @@ def dendrogram( groupby = [groupby] for group in groupby: if group not in adata.obs_keys(): - raise ValueError( + msg = ( "groupby has to be a valid observation. " f"Given value: {group}, valid observations: {adata.obs_keys()}" ) + raise ValueError(msg) if not isinstance(adata.obs[group].dtype, CategoricalDtype): - raise ValueError( + msg = ( "groupby has to be a categorical observation. " f"Given value: {group}, Column type: {adata.obs[group].dtype}" ) + raise ValueError(msg) if var_names is None: rep_df = pd.DataFrame( @@ -188,7 +190,7 @@ def dendrogram( if inplace: if key_added is None: - key_added = f'dendrogram_{"_".join(groupby)}' + key_added = f"dendrogram_{'_'.join(groupby)}" logg.info(f"Storing dendrogram info using `.uns[{key_added!r}]`") adata.uns[key_added] = dat else: diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index d2bdcc647b..8554552252 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -77,11 +77,11 @@ def diffmap( neighbors_key = "neighbors" if neighbors_key not in adata.uns: - raise ValueError( - "You need to run `pp.neighbors` first to compute a neighborhood graph." - ) + msg = "You need to run `pp.neighbors` first to compute a neighborhood graph." + raise ValueError(msg) if n_comps <= 2: - raise ValueError("Provide any value greater than 2 for `n_comps`. ") + msg = "Provide any value greater than 2 for `n_comps`. " + raise ValueError(msg) adata = adata.copy() if copy else adata _diffmap( adata, n_comps=n_comps, neighbors_key=neighbors_key, random_state=random_state diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index c0fa59262f..e92fc726c6 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -18,7 +18,7 @@ def _diffmap(adata, n_comps=15, neighbors_key=None, random_state=0): - start = logg.info(f"computing Diffusion Maps using n_comps={n_comps}(=n_dcs)") + start = logg.info(f"computing Diffusion Maps using {n_comps=}(=n_dcs)") dpt = DPT(adata, neighbors_key=neighbors_key) dpt.compute_transitions() dpt.compute_eigen(n_comps=n_comps, random_state=random_state) @@ -129,7 +129,8 @@ def dpt( if neighbors_key is None: neighbors_key = "neighbors" if neighbors_key not in adata.uns: - raise ValueError("You need to run `pp.neighbors` and `tl.diffmap` first.") + msg = "You need to run `pp.neighbors` and `tl.diffmap` first." + raise ValueError(msg) if "iroot" not in adata.uns and "xroot" not in adata.var: logg.warning( "No root cell found. 
To compute pseudotime, pass the index or " @@ -152,7 +153,7 @@ def dpt( allow_kendall_tau_shift=allow_kendall_tau_shift, neighbors_key=neighbors_key, ) - start = logg.info(f"computing Diffusion Pseudotime using n_dcs={n_dcs}") + start = logg.info(f"computing Diffusion Pseudotime using {n_dcs=}") if n_branchings > 1: logg.info(" this uses a hierarchical implementation") if dpt.iroot is not None: @@ -262,7 +263,7 @@ def detect_branchings(self): """ logg.debug( f" detect {self.n_branchings} " - f'branching{"" if self.n_branchings == 1 else "s"}', + f"branching{'' if self.n_branchings == 1 else 's'}", ) # a segment is a subset of points of the data set (defined by the # indices of the points in the segment) @@ -799,9 +800,8 @@ def _detect_branching( elif self.flavor == "wolf17_bi" or self.flavor == "wolf17_bi_un": ssegs = self._detect_branching_single_wolf17_bi(Dseg, tips) else: - raise ValueError( - '`flavor` needs to be in {"haghverdi16", "wolf17_tri", "wolf17_bi"}.' - ) + msg = '`flavor` needs to be in {"haghverdi16", "wolf17_tri", "wolf17_bi"}.' + raise ValueError(msg) # make sure that each data point has a unique association with a segment masks = np.zeros((len(ssegs), Dseg.shape[0]), dtype=bool) for iseg, seg in enumerate(ssegs): @@ -1039,9 +1039,11 @@ def kendall_tau_split(self, a: np.ndarray, b: np.ndarray) -> int: Splitting index according to above description. """ if a.size != b.size: - raise ValueError("a and b need to have the same size") + msg = "a and b need to have the same size" + raise ValueError(msg) if a.ndim != b.ndim != 1: - raise ValueError("a and b need to be one-dimensional arrays") + msg = "a and b need to be one-dimensional arrays" + raise ValueError(msg) import scipy as sp min_length = 5 diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index aedd41f3d3..d0a70b3f4f 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -124,7 +124,8 @@ def draw_graph( """ start = logg.info(f"drawing single-cell graph using layout {layout!r}") if layout not in (layouts := get_literal_vals(_Layout)): - raise ValueError(f"Provide a valid layout, one of {layouts}.") + msg = f"Provide a valid layout, one of {layouts}." + raise ValueError(msg) adata = adata.copy() if copy else adata if adjacency is None: adjacency = _choose_graph(adata, obsp, neighbors_key) diff --git a/src/scanpy/tools/_embedding_density.py b/src/scanpy/tools/_embedding_density.py index 5ae69361dc..d539848b98 100644 --- a/src/scanpy/tools/_embedding_density.py +++ b/src/scanpy/tools/_embedding_density.py @@ -130,10 +130,11 @@ def embedding_density( basis = "draw_graph_fa" if f"X_{basis}" not in adata.obsm_keys(): - raise ValueError( + msg = ( "Cannot find the embedded representation " f"`adata.obsm['X_{basis}']`. Compute the embedding first." ) + raise ValueError(msg) if components is None: components = "1,2" @@ -142,17 +143,20 @@ def embedding_density( components = np.array(components).astype(int) - 1 if len(components) != 2: - raise ValueError("Please specify exactly 2 components, or `None`.") + msg = "Please specify exactly 2 components, or `None`." + raise ValueError(msg) if basis == "diffmap": components += 1 if groupby is not None: if groupby not in adata.obs: - raise ValueError(f"Could not find {groupby!r} `.obs` column.") + msg = f"Could not find {groupby!r} `.obs` column." 
+ raise ValueError(msg) if adata.obs[groupby].dtype.name != "category": - raise ValueError(f"{groupby!r} column does not contain categorical data") + msg = f"{groupby!r} column does not contain categorical data" + raise ValueError(msg) # Define new covariate name if key_added is not None: diff --git a/src/scanpy/tools/_ingest.py b/src/scanpy/tools/_ingest.py index 3698067035..2a47e095a0 100644 --- a/src/scanpy/tools/_ingest.py +++ b/src/scanpy/tools/_ingest.py @@ -123,11 +123,12 @@ def ingest( # anndata version check anndata_version = pkg_version("anndata") if anndata_version < ANNDATA_MIN_VERSION: - raise ValueError( + msg = ( f"ingest only works correctly with anndata>={ANNDATA_MIN_VERSION} " f"(you have {anndata_version}) as prior to {ANNDATA_MIN_VERSION}, " "`AnnData.concatenate` did not concatenate `.obsm`." ) + raise ValueError(msg) start = logg.info("running ingest") obs = [obs] if isinstance(obs, str) else obs @@ -187,12 +188,13 @@ def __init__(self, dim, axis=0, vals=None): def __setitem__(self, key, value): if value.shape[self._axis] != self._dim: - raise ValueError( - f"Value passed for key '{key}' is of incorrect shape. " + msg = ( + f"Value passed for key {key!r} is of incorrect shape. " f"Value has shape {value.shape[self._axis]} " f"for dimension {self._axis} while " f"it should have {self._dim}." ) + raise ValueError(msg) self._data[key] = value def __getitem__(self, key): @@ -340,10 +342,11 @@ def __init__(self, adata: AnnData, neighbors_key: str | None = None): if neighbors_key in adata.uns: self._init_neighbors(adata, neighbors_key) else: - raise ValueError( + msg = ( f'There is no neighbors data in `adata.uns["{neighbors_key}"]`.\n' "Please run pp.neighbors." ) + raise ValueError(msg) if "X_umap" in adata.obsm: self._init_umap(adata) @@ -393,10 +396,11 @@ def fit(self, adata_new): new_var_names = adata_new.var_names.str.upper() if not ref_var_names.equals(new_var_names): - raise ValueError( + msg = ( "Variables in the new adata are different " "from variables in the reference adata" ) + raise ValueError(msg) self._obs = pd.DataFrame(index=adata_new.obs.index) self._obsm = _DimDict(adata_new.n_obs, axis=0) @@ -440,9 +444,8 @@ def map_embedding(self, method): elif method == "pca": self._obsm["X_pca"] = self._pca() else: - raise NotImplementedError( - "Ingest supports only umap and pca embeddings for now." - ) + msg = "Ingest supports only umap and pca embeddings for now." + raise NotImplementedError(msg) def _knn_classify(self, labels): # ensure it's categorical @@ -461,7 +464,8 @@ def map_labels(self, labels, method): if method == "knn": self._obs[labels] = self._knn_classify(labels) else: - raise NotImplementedError("Ingest supports knn labeling for now.") + msg = "Ingest supports knn labeling for now." + raise NotImplementedError(msg) @old_positionals("inplace") def to_adata(self, *, inplace: bool = False) -> AnnData | None: diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index f73ec1fd7d..9f1fbf23ef 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -120,19 +120,18 @@ def leiden( and `n_iterations`. """ if flavor not in {"igraph", "leidenalg"}: - raise ValueError( - f"flavor must be either 'igraph' or 'leidenalg', but '{flavor}' was passed" + msg = ( + f"flavor must be either 'igraph' or 'leidenalg', but {flavor!r} was passed" ) + raise ValueError(msg) _utils.ensure_igraph() if flavor == "igraph": if directed: - raise ValueError( - "Cannot use igraph’s leiden implementation with a directed graph." 
- ) + msg = "Cannot use igraph’s leiden implementation with a directed graph." + raise ValueError(msg) if partition_type is not None: - raise ValueError( - "Do not pass in partition_type argument when using igraph." - ) + msg = "Do not pass in partition_type argument when using igraph." + raise ValueError(msg) else: try: import leidenalg @@ -140,9 +139,8 @@ def leiden( msg = 'In the future, the default backend for leiden will be igraph instead of leidenalg.\n\n To achieve the future defaults please pass: flavor="igraph" and n_iterations=2. directed must also be False to work with igraph\'s implementation.' _utils.warn_once(msg, FutureWarning, stacklevel=3) except ImportError: - raise ImportError( - "Please install the leiden algorithm: `conda install -c conda-forge leidenalg` or `pip3 install leidenalg`." - ) + msg = "Please install the leiden algorithm: `conda install -c conda-forge leidenalg` or `pip3 install leidenalg`." + raise ImportError(msg) clustering_args = dict(clustering_args) start = logg.info("running Leiden clustering") diff --git a/src/scanpy/tools/_louvain.py b/src/scanpy/tools/_louvain.py index 470858ff38..50181229ab 100644 --- a/src/scanpy/tools/_louvain.py +++ b/src/scanpy/tools/_louvain.py @@ -143,9 +143,8 @@ def louvain( partition_kwargs = dict(partition_kwargs) start = logg.info("running Louvain clustering") if (flavor != "vtraag") and (partition_type is not None): - raise ValueError( - "`partition_type` is only a valid argument " 'when `flavour` is "vtraag"' - ) + msg = '`partition_type` is only a valid argument when `flavour` is "vtraag"' + raise ValueError(msg) adata = adata.copy() if copy else adata if adjacency is None: adjacency = _choose_graph(adata, obsp, neighbors_key) @@ -239,7 +238,8 @@ def louvain( for k, v in partition.items(): groups[k] = v else: - raise ValueError('`flavor` needs to be "vtraag" or "igraph" or "taynaud".') + msg = '`flavor` needs to be "vtraag" or "igraph" or "taynaud".' + raise ValueError(msg) if restrict_to is not None: if key_added == "louvain": key_added += "_R" diff --git a/src/scanpy/tools/_marker_gene_overlap.py b/src/scanpy/tools/_marker_gene_overlap.py index eb07b84885..1860fd73df 100644 --- a/src/scanpy/tools/_marker_gene_overlap.py +++ b/src/scanpy/tools/_marker_gene_overlap.py @@ -162,30 +162,35 @@ def marker_gene_overlap( """ # Test user inputs if inplace: - raise NotImplementedError( + msg = ( "Writing Pandas dataframes to h5ad is currently under development." "\nPlease use `inplace=False`." ) + raise NotImplementedError(msg) if key not in adata.uns: - raise ValueError( + msg = ( "Could not find marker gene data. " "Please run `sc.tl.rank_genes_groups()` first." ) + raise ValueError(msg) avail_methods = {"overlap_count", "overlap_coef", "jaccard", "enrich"} if method not in avail_methods: - raise ValueError(f"Method must be one of {avail_methods}.") + msg = f"Method must be one of {avail_methods}." + raise ValueError(msg) if normalize == "None": normalize = None avail_norm = {"reference", "data", None} if normalize not in avail_norm: - raise ValueError(f"Normalize must be one of {avail_norm}.") + msg = f"Normalize must be one of {avail_norm}." + raise ValueError(msg) if normalize is not None and method != "overlap_count": - raise ValueError("Can only normalize with method=`overlap_count`.") + msg = "Can only normalize with method=`overlap_count`." 
+            raise ValueError(msg)
 
     if not all(isinstance(val, AbstractSet) for val in reference_markers.values()):
         try:
@@ -193,18 +198,20 @@
                 key: set(val) for key, val in reference_markers.items()
             }
         except Exception:
-            raise ValueError(
+            msg = (
                 "Please ensure that `reference_markers` contains "
                 "sets or lists of markers as values."
             )
+            raise ValueError(msg)
 
     if adj_pval_threshold is not None:
         if "pvals_adj" not in adata.uns[key]:
-            raise ValueError(
+            msg = (
                 "Could not find adjusted p-value data. "
                 "Please run `sc.tl.rank_genes_groups()` with a "
                 "method that outputs adjusted p-values."
             )
+            raise ValueError(msg)
 
         if adj_pval_threshold < 0:
             logg.warning(
diff --git a/src/scanpy/tools/_paga.py b/src/scanpy/tools/_paga.py
index 98146b83e2..b7f1e86e5d 100644
--- a/src/scanpy/tools/_paga.py
+++ b/src/scanpy/tools/_paga.py
@@ -107,21 +107,22 @@ def paga(
     """
     check_neighbors = "neighbors" if neighbors_key is None else neighbors_key
     if check_neighbors not in adata.uns:
-        raise ValueError(
-            "You need to run `pp.neighbors` first to compute a neighborhood graph."
-        )
+        msg = "You need to run `pp.neighbors` first to compute a neighborhood graph."
+        raise ValueError(msg)
     if groups is None:
         for k in ("leiden", "louvain"):
             if k in adata.obs.columns:
                 groups = k
                 break
     if groups is None:
-        raise ValueError(
+        msg = (
             "You need to run `tl.leiden` or `tl.louvain` to compute "
             "community labels, or specify `groups='an_existing_key'`"
         )
+        raise ValueError(msg)
     elif groups not in adata.obs.columns:
-        raise KeyError(f"`groups` key {groups!r} not found in `adata.obs`.")
+        msg = f"`groups` key {groups!r} not found in `adata.obs`."
+        raise KeyError(msg)
 
     adata = adata.copy() if copy else adata
     _utils.sanitize_anndata(adata)
@@ -170,9 +171,8 @@ def compute_connectivities(self):
         elif self._model == "v1.0":
             return self._compute_connectivities_v1_0()
         else:
-            raise ValueError(
-                f"`model` {self._model} needs to be one of {_AVAIL_MODELS}."
-            )
+            msg = f"`model` {self._model} needs to be one of {_AVAIL_MODELS}."
+            raise ValueError(msg)
 
     def _compute_connectivities_v1_2(self):
         import igraph
@@ -273,15 +273,17 @@ def compute_transitions(self):
                 "The key 'velocyto_transitions' has been changed to 'velocity_graph'."
             )
         else:
-            raise ValueError(
+            msg = (
                 "The passed AnnData needs to have an `uns` annotation "
                 "with key 'velocity_graph' - a sparse matrix from RNA velocity."
             )
+            raise ValueError(msg)
         if self._adata.uns[vkey].shape != (self._adata.n_obs, self._adata.n_obs):
-            raise ValueError(
+            msg = (
                 f"The passed 'velocity_graph' has shape {self._adata.uns[vkey].shape} "
                 f"but should have shape {(self._adata.n_obs, self._adata.n_obs)}"
             )
+            raise ValueError(msg)
         # restore this at some point
         # if 'expected_n_edges_random' not in self._adata.uns['paga']:
         #     raise ValueError(
diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py
index cafb78c6f1..05e5738d99 100644
--- a/src/scanpy/tools/_rank_genes_groups.py
+++ b/src/scanpy/tools/_rank_genes_groups.py
@@ -132,7 +132,8 @@ def __init__(
         adata_comp = adata
         if layer is not None:
             if use_raw:
-                raise ValueError("Cannot specify `layer` and have `use_raw=True`.")
+                msg = "Cannot specify `layer` and have `use_raw=True`."
+                raise ValueError(msg)
             X = adata_comp.layers[layer]
         else:
             if use_raw and adata.raw is not None:
@@ -253,7 +254,8 @@ def t_test(
             # hack for overestimating the variance for small groups
             ns_rest = ns_group
         else:
-            raise ValueError("Method does not exist.")
+            msg = "Method does not exist."
+ raise ValueError(msg) # TODO: Come up with better solution. Mask unexpressed genes? # See https://github.com/scipy/scipy/issues/10269 @@ -369,7 +371,8 @@ def logreg( X = self.X[self.grouping_mask.values, :] if len(self.groups_order) == 1: - raise ValueError("Cannot perform logistic regression on a single cluster.") + msg = "Cannot perform logistic regression on a single cluster." + raise ValueError(msg) clf = LogisticRegression(**kwds) clf.fit(X, self.grouping.cat.codes) @@ -598,7 +601,8 @@ def rank_genes_groups( if use_raw is None: use_raw = adata.raw is not None elif use_raw is True and adata.raw is None: - raise ValueError("Received `use_raw=True`, but `adata.raw` is empty.") + msg = "Received `use_raw=True`, but `adata.raw` is empty." + raise ValueError(msg) if method is None: method = "t-test" @@ -608,11 +612,13 @@ def rank_genes_groups( start = logg.info("ranking genes") if method not in (avail_methods := get_literal_vals(_Method)): - raise ValueError(f"Method must be one of {avail_methods}.") + msg = f"Method must be one of {avail_methods}." + raise ValueError(msg) avail_corr = {"benjamini-hochberg", "bonferroni"} if corr_method not in avail_corr: - raise ValueError(f"Correction method must be one of {avail_corr}.") + msg = f"Correction method must be one of {avail_corr}." + raise ValueError(msg) adata = adata.copy() if copy else adata _utils.sanitize_anndata(adata) @@ -620,7 +626,8 @@ def rank_genes_groups( if groups == "all": groups_order = "all" elif isinstance(groups, str | int): - raise ValueError("Specify a sequence of groups") + msg = "Specify a sequence of groups" + raise ValueError(msg) else: groups_order = list(groups) if isinstance(groups_order[0], int): @@ -629,9 +636,8 @@ def rank_genes_groups( groups_order += [reference] if reference != "rest" and reference not in adata.obs[groupby].cat.categories: cats = adata.obs[groupby].cat.categories.tolist() - raise ValueError( - f"reference = {reference} needs to be one of groupby = {cats}." - ) + msg = f"reference = {reference} needs to be one of groupby = {cats}." + raise ValueError(msg) if key_added is None: key_added = "rank_genes_groups" diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index a40d9f3288..d0a33fdb97 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -38,7 +38,8 @@ def _sparse_nanmean( np.nanmean equivalent for sparse matrices """ if not issparse(X): - raise TypeError("X must be a sparse matrix") + msg = "X must be a sparse matrix" + raise TypeError(msg) # count the number of nan elements per row/column (dep. on axis) Z = X.copy() @@ -130,9 +131,8 @@ def score_genes( adata = adata.copy() if copy else adata use_raw = _check_use_raw(adata, use_raw, layer=layer) if is_backed_type(adata.X) and not use_raw: - raise NotImplementedError( - f"score_genes is not implemented for matrices of type {type(adata.X)}" - ) + msg = f"score_genes is not implemented for matrices of type {type(adata.X)}" + raise NotImplementedError(msg) if random_state is not None: np.random.seed(random_state) @@ -204,14 +204,16 @@ def _check_score_genes_args( if len(genes_to_ignore) > 0: logg.warning(f"genes are not in var_names and ignored: {genes_to_ignore}") if len(gene_list) == 0: - raise ValueError("No valid genes were passed for scoring.") + msg = "No valid genes were passed for scoring." 
+ raise ValueError(msg) if gene_pool is None: gene_pool = var_names.astype("string") else: gene_pool = pd.Index(gene_pool, dtype="string").intersection(var_names) if len(gene_pool) == 0: - raise ValueError("No valid genes were passed for reference set.") + msg = "No valid genes were passed for reference set." + raise ValueError(msg) def get_subset(genes: pd.Index[str]): x = _get_obs_rep(adata, use_raw=use_raw, layer=layer) diff --git a/src/scanpy/tools/_sim.py b/src/scanpy/tools/_sim.py index 7410442952..a53575fa29 100644 --- a/src/scanpy/tools/_sim.py +++ b/src/scanpy/tools/_sim.py @@ -120,7 +120,7 @@ def add_args(p): "default": "", "metavar": "f", "type": str, - "help": "Specify a parameter file " '(default: "sim/${exkey}_params.txt")', + "help": 'Specify a parameter file (default: "sim/${exkey}_params.txt")', } } p = _utils.add_args(p, dadd_args) @@ -216,7 +216,7 @@ def sample_dynamic_data(**params): break logg.debug( f"mean nr of offdiagonal edges {nrOffEdges_list.mean()} " - f"compared to total nr {grnsim.dim * (grnsim.dim - 1) / 2.}" + f"compared to total nr {grnsim.dim * (grnsim.dim - 1) / 2.0}" ) # more complex models @@ -358,15 +358,13 @@ def write_data( for g in range(dim): if np.abs(Coupl[gp, g]) > 1e-10: f.write( - f"{names[gp]:10} " - f"{names[g]:10} " - f"{Coupl[gp, g]:10.3} \n" + f"{names[gp]:10} {names[g]:10} {Coupl[gp, g]:10.3} \n" ) # write simulated data # the binary mode option in the following line is a fix for python 3 # variable names if varNames: - header += f'{"it":>2} ' + header += f"{'it':>2} " for v in varNames: header += f"{v:>7} " with (dir / f"sim_{id}.txt").open("ab" if append else "wb") as f: @@ -429,7 +427,8 @@ def __init__( self.verbosity = verbosity # checks if initType not in ["branch", "random"]: - raise RuntimeError("initType must be either: branch, random") + msg = "initType must be either: branch, random" + raise RuntimeError(msg) if model not in self.availModels: message = "model not among predefined models \n" # noqa: F841 # TODO FIX # read from file @@ -437,7 +436,8 @@ def __init__( model = Path(sim_models.__file__).parent / f"{model}.txt" if not model.is_file(): - raise RuntimeError(f"Model file {model} does not exist") + msg = f"Model file {model} does not exist" + raise RuntimeError(msg) self.model = model # set the coupling matrix, and with that the adjacency matrix self.set_coupl(Coupl=Coupl) @@ -461,7 +461,8 @@ def sim_model(self, tmax, X0, noiseDyn=0, restart=0): elif self.modelType == "var": Xdiff = self.Xdiff_var(X[t - 1]) else: - raise ValueError(f"Unknown modelType {self.modelType!r}") + msg = f"Unknown modelType {self.modelType!r}" + raise ValueError(msg) X[t] = X[t - 1] + Xdiff # add dynamic noise X[t] += noiseDyn * np.random.randn(self.dim) @@ -501,7 +502,7 @@ def Xdiff_hill(self, Xt): ) if verbosity > 0: Xdiff_syn_tuple_str += ( - f'{"a" if v else "i"}' + f"{'a' if v else 'i'}" f"({self.pas[child][iv]}, {threshold:.2})" ) Xdiff_syn += Xdiff_syn_tuple @@ -853,12 +854,12 @@ def build_boolCoeff(self): for g in range(self.dim): if g in pasIndices: if np.abs(self.Coupl[self.varNames[key], g]) < 1e-10: - raise ValueError(f"specify coupling value for {key} <- {g}") + msg = f"specify coupling value for {key} <- {g}" + raise ValueError(msg) else: if np.abs(self.Coupl[self.varNames[key], g]) > 1e-10: - raise ValueError( - "there should be no coupling value for " f"{key} <- {g}" - ) + msg = f"there should be no coupling value for {key} <- {g}" + raise ValueError(msg) if self.verbosity > 1: settings.m(0, "..." 
+ key)
                 settings.m(0, rule)
@@ -957,7 +958,7 @@ def _check_branching(
             check = False
         if check:
             Xsamples.append(X)
-    logg.debug(f'realization {restart}: {"" if check else "no"} new branch')
+    logg.debug(f"realization {restart}: {'' if check else 'no'} new branch")
     return check, Xsamples
 
 
@@ -1047,9 +1048,8 @@ def sample_coupling_matrix(
                 check = True
                 break
         if not check:
-            raise ValueError(
-                "did not find graph without cycles after" f"{max_trial} trials"
-            )
+            msg = f"did not find graph without cycles after {max_trial} trials"
+            raise ValueError(msg)
 
     return Coupl, Adj, Adj_signed, n_edges
 
diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py
index 902171d58c..926e6d3d4f 100644
--- a/src/scanpy/tools/_umap.py
+++ b/src/scanpy/tools/_umap.py
@@ -164,9 +164,8 @@ def umap(
     if neighbors_key is None:  # backwards compat
         neighbors_key = "neighbors"
     if neighbors_key not in adata.uns:
-        raise ValueError(
-            f"Did not find .uns[{neighbors_key!r}]. Run `sc.pp.neighbors` first."
-        )
+        msg = f"Did not find .uns[{neighbors_key!r}]. Run `sc.pp.neighbors` first."
+        raise ValueError(msg)
 
     start = logg.info("computing UMAP")
 
@@ -241,10 +240,11 @@ def umap(
             warnings.warn(msg, FutureWarning)
         metric = neigh_params.get("metric", "euclidean")
         if metric != "euclidean":
-            raise ValueError(
+            msg = (
                 f"`sc.pp.neighbors` was called with `metric` {metric!r}, "
                 "but umap `method` 'rapids' only supports the 'euclidean' metric."
             )
+            raise ValueError(msg)
         from cuml import UMAP
 
         n_neighbors = neighbors["params"]["n_neighbors"]
diff --git a/src/scanpy/tools/_utils.py b/src/scanpy/tools/_utils.py
index 97e2de0df1..4d24b5e276 100644
--- a/src/scanpy/tools/_utils.py
+++ b/src/scanpy/tools/_utils.py
@@ -32,9 +32,8 @@ def _choose_representation(
     if adata.n_vars > settings.N_PCS:
         if "X_pca" in adata.obsm:
             if n_pcs is not None and n_pcs > adata.obsm["X_pca"].shape[1]:
-                raise ValueError(
-                    "`X_pca` does not have enough PCs. Rerun `sc.pp.pca` with adjusted `n_comps`."
-                )
+                msg = "`X_pca` does not have enough PCs. Rerun `sc.pp.pca` with adjusted `n_comps`."
+                raise ValueError(msg)
             X = adata.obsm["X_pca"][:, :n_pcs]
             logg.info(f"    using 'X_pca' with n_pcs = {X.shape[1]}")
         else:
@@ -52,21 +51,23 @@ def _choose_representation(
     else:
         if use_rep in adata.obsm and n_pcs is not None:
             if n_pcs > adata.obsm[use_rep].shape[1]:
-                raise ValueError(
+                msg = (
                     f"{use_rep} does not have enough dimensions. Provide a "
                     "representation with equal or more dimensions than "
                     "`n_pcs`, or lower `n_pcs`."
                 )
+                raise ValueError(msg)
             X = adata.obsm[use_rep][:, :n_pcs]
         elif use_rep in adata.obsm and n_pcs is None:
            X = adata.obsm[use_rep]
         elif use_rep == "X":
             X = adata.X
         else:
-            raise ValueError(
+            msg = (
                 f"Did not find {use_rep} in `.obsm.keys()`. "
                 "You need to compute it first."
)
+            raise ValueError(msg)
         settings.verbosity = verbosity  # resetting verbosity
     return X
 
@@ -86,7 +87,7 @@ def preprocess_with_pca(adata, n_pcs: int | None = None, random_state=0):
         logg.info("    using data matrix X directly (no PCA)")
         return adata.X
     elif n_pcs is None and "X_pca" in adata.obsm_keys():
-        logg.info(f'    using \'X_pca\' with n_pcs = {adata.obsm["X_pca"].shape[1]}')
+        logg.info(f"    using 'X_pca' with n_pcs = {adata.obsm['X_pca'].shape[1]}")
         return adata.obsm["X_pca"]
     elif "X_pca" in adata.obsm_keys() and adata.obsm["X_pca"].shape[1] >= n_pcs:
         logg.info(f"    using 'X_pca' with n_pcs = {n_pcs}")
@@ -128,5 +129,6 @@ def get_init_pos_from_paga(
         else:
             init_pos[subset] = group_pos
     else:
-        raise ValueError("Plot PAGA first, so that adata.uns['paga']" "with key 'pos'.")
+        msg = "Plot PAGA first, so that adata.uns['paga'] exists with key 'pos'."
+        raise ValueError(msg)
     return init_pos
diff --git a/src/scanpy/tools/_utils_clustering.py b/src/scanpy/tools/_utils_clustering.py
index 47f652fbdf..3c771e5d74 100644
--- a/src/scanpy/tools/_utils_clustering.py
+++ b/src/scanpy/tools/_utils_clustering.py
@@ -37,12 +37,12 @@ def restrict_adjacency(
     adjacency: spmatrix,
 ) -> tuple[spmatrix, NDArray[np.bool_]]:
     if not isinstance(restrict_categories[0], str):
-        raise ValueError(
-            "You need to use strings to label categories, " "e.g. '1' instead of 1."
-        )
+        msg = "You need to use strings to label categories, e.g. '1' instead of 1."
+        raise ValueError(msg)
     for c in restrict_categories:
         if c not in adata.obs[restrict_key].cat.categories:
-            raise ValueError(f"'{c}' is not a valid category for '{restrict_key}'")
+            msg = f"{c!r} is not a valid category for {restrict_key!r}"
+            raise ValueError(msg)
     restrict_indices = adata.obs[restrict_key].isin(restrict_categories).values
     adjacency = adjacency[restrict_indices, :]
     adjacency = adjacency[:, restrict_indices]
diff --git a/src/testing/scanpy/_pytest/__init__.py b/src/testing/scanpy/_pytest/__init__.py
index 318baac1aa..e365a90495 100644
--- a/src/testing/scanpy/_pytest/__init__.py
+++ b/src/testing/scanpy/_pytest/__init__.py
@@ -75,8 +75,7 @@ def pytest_addoption(parser: pytest.Parser) -> None:
         action="store_true",
         default=False,
         help=(
-            "Run tests that retrieve stuff from the internet. "
-            "This increases test time."
+            "Run tests that retrieve stuff from the internet. This increases test time."
         ),
     )
 
@@ -131,6 +130,6 @@ def pytest_itemcollected(item: pytest.Item) -> None:
     )
 
 
-assert (
-    "scanpy" not in sys.modules
-), "scanpy is already imported, this will mess up test coverage"
+assert "scanpy" not in sys.modules, (
+    "scanpy is already imported, this will mess up test coverage"
+)
diff --git a/tests/conftest.py b/tests/conftest.py
index 4cbe5ff53e..2d7f8e7aad 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -133,7 +133,8 @@ def save_and_compare(*path_parts: Path | os.PathLike, tol: int):
         plt.savefig(actual_pth, dpi=40)
         plt.close()
         if not expected_pth.is_file():
-            raise OSError(f"No expected output found at {expected_pth}.")
+            msg = f"No expected output found at {expected_pth}."
+            raise OSError(msg)
         check_same_image(expected_pth, actual_pth, tol=tol)
 
     return save_and_compare
diff --git a/tests/external/test_wishbone.py b/tests/external/test_wishbone.py
index 7fadef63c6..db649a5b9d 100644
--- a/tests/external/test_wishbone.py
+++ b/tests/external/test_wishbone.py
@@ -22,6 +22,6 @@ def test_run_wishbone():
         components=[2, 3],
         num_waypoints=150,
     )
-    assert all(
-        [k in adata.obs for k in ["trajectory_wishbone", "branch_wishbone"]]
-    ), "Run Wishbone Error!"
+ assert all([k in adata.obs for k in ["trajectory_wishbone", "branch_wishbone"]]), ( + "Run Wishbone Error!" + ) diff --git a/tests/test_dendrogram.py b/tests/test_dendrogram.py index 18b952eff2..44a08fcf67 100644 --- a/tests/test_dendrogram.py +++ b/tests/test_dendrogram.py @@ -18,7 +18,7 @@ def test_dendrogram_key_added(groupby, key_added): adata = pbmc68k_reduced() sc.tl.dendrogram(adata, groupby=groupby, key_added=key_added, use_rep="X_pca") if isinstance(groupby, list): - dendrogram_key = f'dendrogram_{"_".join(groupby)}' + dendrogram_key = f"dendrogram_{'_'.join(groupby)}" else: dendrogram_key = f"dendrogram_{groupby}" diff --git a/tests/test_get.py b/tests/test_get.py index 673b26787d..05cb1b6a9d 100644 --- a/tests/test_get.py +++ b/tests/test_get.py @@ -24,7 +24,7 @@ def transpose_adata(adata: AnnData, *, expect_duplicates: bool = False) -> AnnDa TRANSPOSE_PARAMS = pytest.mark.parametrize( - "dim,transform,func", + ("dim", "transform", "func"), [ ("obs", lambda x, expect_duplicates=False: x, sc.get.obs_df), ("var", transpose_adata, sc.get.var_df), diff --git a/tests/test_highly_variable_genes.py b/tests/test_highly_variable_genes.py index 7d9fdac9fa..528a86ea99 100644 --- a/tests/test_highly_variable_genes.py +++ b/tests/test_highly_variable_genes.py @@ -629,7 +629,8 @@ def test_subset_inplace_consistency(flavor, array_type, batch_key): pass else: - raise ValueError(f"Unknown flavor {flavor}") + msg = f"Unknown flavor {flavor}" + raise ValueError(msg) n_genes = adata.shape[1] diff --git a/tests/test_normalization.py b/tests/test_normalization.py index 3acefe1bb1..9cf20c0b52 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -198,12 +198,12 @@ def _check_pearson_pca_fields(ad, n_cells, n_comps): "Missing `.uns` keys. 
Expected `['pearson_residuals_normalization', 'pca']`, " f"but only {list(ad.uns.keys())} were found" ) - assert ( - "X_pca" in ad.obsm - ), f"Missing `obsm` key `'X_pca'`, only {list(ad.obsm.keys())} were found" - assert ( - "PCs" in ad.varm - ), f"Missing `varm` key `'PCs'`, only {list(ad.varm.keys())} were found" + assert "X_pca" in ad.obsm, ( + f"Missing `obsm` key `'X_pca'`, only {list(ad.obsm.keys())} were found" + ) + assert "PCs" in ad.varm, ( + f"Missing `varm` key `'PCs'`, only {list(ad.varm.keys())} were found" + ) assert ad.obsm["X_pca"].shape == ( n_cells, n_comps, diff --git a/tests/test_rank_genes_groups.py b/tests/test_rank_genes_groups.py index 788c7e705d..b938fd2ca3 100644 --- a/tests/test_rank_genes_groups.py +++ b/tests/test_rank_genes_groups.py @@ -59,14 +59,12 @@ def get_example_data(array_type: Callable[[np.ndarray], Any]) -> AnnData: return adata -def get_true_scores() -> ( - tuple[ - NDArray[np.object_], - NDArray[np.object_], - NDArray[np.floating], - NDArray[np.floating], - ] -): +def get_true_scores() -> tuple[ + NDArray[np.object_], + NDArray[np.object_], + NDArray[np.floating], + NDArray[np.floating], +]: with (DATA_PATH / "objs_t_test.pkl").open("rb") as f: true_scores_t_test, true_names_t_test = pickle.load(f) with (DATA_PATH / "objs_wilcoxon.pkl").open("rb") as f: From fa38823a4918d5deaefa2bb42fd9d48053636f1a Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Wed, 15 Jan 2025 06:14:04 -0500 Subject: [PATCH 19/24] Grammar fixes in `sc.tl` docstrings (#3438) * typo and grammar fixes in docstrings only * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert dendogram docstring --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/scanpy/tools/_dendrogram.py | 10 +++++----- src/scanpy/tools/_diffmap.py | 18 +++++++++--------- src/scanpy/tools/_dpt.py | 20 ++++++++++---------- src/scanpy/tools/_draw_graph.py | 8 ++++---- src/scanpy/tools/_embedding_density.py | 2 +- src/scanpy/tools/_ingest.py | 4 ++-- src/scanpy/tools/_leiden.py | 8 ++++---- src/scanpy/tools/_louvain.py | 4 ++-- src/scanpy/tools/_marker_gene_overlap.py | 2 +- src/scanpy/tools/_score_genes.py | 4 ++-- src/scanpy/tools/_sim.py | 2 +- src/scanpy/tools/_top_genes.py | 4 ++-- src/scanpy/tools/_tsne.py | 2 +- 13 files changed, 44 insertions(+), 44 deletions(-) diff --git a/src/scanpy/tools/_dendrogram.py b/src/scanpy/tools/_dendrogram.py index b31e792f31..f33aca1ff7 100644 --- a/src/scanpy/tools/_dendrogram.py +++ b/src/scanpy/tools/_dendrogram.py @@ -60,8 +60,8 @@ def dendrogram( to compute a correlation matrix. The hierarchical clustering can be visualized using - :func:`scanpy.pl.dendrogram` or multiple other visualizations that can - include a dendrogram: :func:`~scanpy.pl.matrixplot`, + :func:`scanpy.pl.dendrogram` or multiple other visualizations + that can include a dendrogram: :func:`~scanpy.pl.matrixplot`, :func:`~scanpy.pl.heatmap`, :func:`~scanpy.pl.dotplot`, and :func:`~scanpy.pl.stacked_violin`. @@ -78,15 +78,15 @@ def dendrogram( {use_rep} var_names List of var_names to use for computing the hierarchical clustering. - If `var_names` is given, then `use_rep` and `n_pcs` is ignored. + If `var_names` is given, then `use_rep` and `n_pcs` are ignored. use_raw Only when `var_names` is not None. Use `raw` attribute of `adata` if present. cor_method - correlation method to use. + Correlation method to use. 
Options are 'pearson', 'kendall', and 'spearman' linkage_method - linkage method to use. See :func:`scipy.cluster.hierarchy.linkage` + Linkage method to use. See :func:`scipy.cluster.hierarchy.linkage` for more information. optimal_ordering Same as the optimal_ordering argument of :func:`scipy.cluster.hierarchy.linkage` diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index 8554552252..b69c2ef18f 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -23,9 +23,9 @@ def diffmap( """\ Diffusion Maps :cite:p:`Coifman2005,Haghverdi2015,Wolf2018`. - Diffusion maps :cite:p:`Coifman2005` has been proposed for visualizing single-cell - data by :cite:t:`Haghverdi2015`. The tool uses the adapted Gaussian kernel suggested - by :cite:t:`Haghverdi2016` in the implementation of :cite:t:`Wolf2018`. + Diffusion maps :cite:p:`Coifman2005` have been proposed for visualizing single-cell + data by :cite:t:`Haghverdi2015`. This tool uses the adapted Gaussian kernel suggested + by :cite:t:`Haghverdi2016` with the implementation of :cite:t:`Wolf2018`. The width ("sigma") of the connectivity kernel is implicitly determined by the number of neighbors used to compute the single-cell graph in @@ -42,12 +42,12 @@ def diffmap( n_comps The number of dimensions of the representation. neighbors_key - If not specified, diffmap looks .uns['neighbors'] for neighbors settings - and .obsp['connectivities'], .obsp['distances'] for connectivities and - distances respectively (default storage places for pp.neighbors). - If specified, diffmap looks .uns[neighbors_key] for neighbors settings and - .obsp[.uns[neighbors_key]['connectivities_key']], - .obsp[.uns[neighbors_key]['distances_key']] for connectivities and distances + If not specified, diffmap looks in .uns['neighbors'] for neighbors settings + and .obsp['connectivities'] and .obsp['distances'] for connectivities and + distances, respectively (default storage places for pp.neighbors). + If specified, diffmap looks in .uns[neighbors_key] for neighbors settings and + .obsp[.uns[neighbors_key]['connectivities_key']] and + .obsp[.uns[neighbors_key]['distances_key']] for connectivities and distances, respectively. random_state A numpy random seed diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index e92fc726c6..a9adc2a112 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -53,7 +53,7 @@ def dpt( :cite:p:`Haghverdi2016,Wolf2019`. Reconstruct the progression of a biological process from snapshot - data. `Diffusion Pseudotime` has been introduced by :cite:t:`Haghverdi2016` and + data. `Diffusion Pseudotime` was introduced by :cite:t:`Haghverdi2016` and implemented within Scanpy :cite:p:`Wolf2018`. Here, we use a further developed version, which is able to deal with disconnected graphs :cite:p:`Wolf2019` and can be run in a `hierarchical` mode by setting the parameter @@ -64,9 +64,9 @@ def dpt( adata.uns['iroot'] = np.flatnonzero(adata.obs['cell_types'] == 'Stem')[0] - This requires to run :func:`~scanpy.pp.neighbors`, first. In order to - reproduce the original implementation of DPT, use `method=='gauss'` in - this. Using the default `method=='umap'` only leads to minor quantitative + This requires running :func:`~scanpy.pp.neighbors`, first. In order to + reproduce the original implementation of DPT, use `method=='gauss'`. + Using the default `method=='umap'` only leads to minor quantitative differences, though. .. 
versionadded:: 1.1
 
@@ -96,12 +96,12 @@ def dpt(
         maximum correlation in Kendall tau criterion of
         :cite:t:`Haghverdi2016` to stabilize the splitting.
     neighbors_key
-        If not specified, dpt looks .uns['neighbors'] for neighbors settings
-        and .obsp['connectivities'], .obsp['distances'] for connectivities and
-        distances respectively (default storage places for pp.neighbors).
-        If specified, dpt looks .uns[neighbors_key] for neighbors settings and
-        .obsp[.uns[neighbors_key]['connectivities_key']],
-        .obsp[.uns[neighbors_key]['distances_key']] for connectivities and distances
+        If not specified, dpt looks in .uns['neighbors'] for neighbors settings
+        and .obsp['connectivities'] and .obsp['distances'] for connectivities and
+        distances, respectively (default storage places for pp.neighbors).
+        If specified, dpt looks in .uns[neighbors_key] for neighbors settings and
+        .obsp[.uns[neighbors_key]['connectivities_key']] and
+        .obsp[.uns[neighbors_key]['distances_key']] for connectivities and distances,
         respectively.
     copy
         Copy instance before computation and return a copy.
diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py
index d0a70b3f4f..0727715671 100644
--- a/src/scanpy/tools/_draw_graph.py
+++ b/src/scanpy/tools/_draw_graph.py
@@ -56,14 +56,14 @@ def draw_graph(
     Force-directed graph drawing :cite:p:`Islam2011,Jacomy2014,Chippada2018`.
 
     An alternative to tSNE that often preserves the topology of the data
-    better. This requires to run :func:`~scanpy.pp.neighbors`, first.
+    better. This requires running :func:`~scanpy.pp.neighbors`, first.
 
     The default layout ('fa', `ForceAtlas2`, :cite:t:`Jacomy2014`) uses the package |fa2-modified|_
     :cite:p:`Chippada2018`, which can be installed via `pip install fa2-modified`.
 
     `Force-directed graph drawing`_ describes a class of long-established
     algorithms for visualizing graphs.
-    It has been suggested for visualizing single-cell data by :cite:t:`Islam2011`.
+    It was suggested for visualizing single-cell data by :cite:t:`Islam2011`.
     Many other layouts as implemented in igraph :cite:p:`Csardi2006` are available.
     Similar approaches have been used by :cite:t:`Zunder2015` or :cite:t:`Weinreb2017`.
 
@@ -98,9 +98,9 @@ def draw_graph(
         Use precomputed coordinates for initialization.
         If `False`/`None` (the default), initialize randomly.
     neighbors_key
-        If not specified, draw_graph looks .obsp['connectivities'] for connectivities
+        If not specified, draw_graph looks at .obsp['connectivities'] for connectivities
         (default storage place for pp.neighbors).
-        If specified, draw_graph looks
+        If specified, draw_graph looks at
         .obsp[.uns[neighbors_key]['connectivities_key']] for connectivities.
     obsp
         Use .obsp[obsp] as adjacency. You can't specify both
diff --git a/src/scanpy/tools/_embedding_density.py b/src/scanpy/tools/_embedding_density.py
index d539848b98..b930cc0aeb 100644
--- a/src/scanpy/tools/_embedding_density.py
+++ b/src/scanpy/tools/_embedding_density.py
@@ -70,7 +70,7 @@ def embedding_density(
         The annotated data matrix.
     basis
         The embedding over which the density will be calculated. This embedded
-        representation should be found in `adata.obsm['X_[basis]']``.
+        representation is found in `adata.obsm['X_[basis]']`.
     groupby
         Key for categorical observation/cell annotation for which densities
         are calculated per category.
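
The `neighbors_key` lookup convention documented in the docstrings above is easiest to see end to end. A minimal usage sketch under the default key names (`pbmc68k_reduced` is just a convenient bundled test dataset, and `n_neighbors=15` and root cell `0` are arbitrary illustrative choices):

    import scanpy as sc

    adata = sc.datasets.pbmc68k_reduced()
    # pp.neighbors writes .uns["neighbors"] plus .obsp["distances"] and
    # .obsp["connectivities"]; these are the default entries the tools
    # read whenever neighbors_key is not given.
    sc.pp.neighbors(adata, n_neighbors=15)
    sc.tl.diffmap(adata)        # reads the default neighbors entries
    adata.uns["iroot"] = 0      # dpt additionally needs a root cell
    sc.tl.dpt(adata)            # writes .obs["dpt_pseudotime"]
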
diff --git a/src/scanpy/tools/_ingest.py b/src/scanpy/tools/_ingest.py index 2a47e095a0..256e1a97c6 100644 --- a/src/scanpy/tools/_ingest.py +++ b/src/scanpy/tools/_ingest.py @@ -91,10 +91,10 @@ def ingest( The method to map labels in `adata_ref.obs` to `adata.obs`. The only supported value is 'knn'. neighbors_key - If not specified, ingest looks adata_ref.uns['neighbors'] + If not specified, ingest looks at adata_ref.uns['neighbors'] for neighbors settings and adata_ref.obsp['distances'] for distances (default storage places for pp.neighbors). - If specified, ingest looks adata_ref.uns[neighbors_key] for + If specified, ingest looks at adata_ref.uns[neighbors_key] for neighbors settings and adata_ref.obsp[adata_ref.uns[neighbors_key]['distances_key']] for distances. inplace diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index 9f1fbf23ef..cccd7f96ce 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -52,9 +52,9 @@ def leiden( Cluster cells using the Leiden algorithm :cite:p:`Traag2019`, an improved version of the Louvain algorithm :cite:p:`Blondel2008`. - It has been proposed for single-cell analysis by :cite:t:`Levine2015`. + It was proposed for single-cell analysis by :cite:t:`Levine2015`. - This requires having ran :func:`~scanpy.pp.neighbors` or + This requires having run :func:`~scanpy.pp.neighbors` or :func:`~scanpy.external.pp.bbknn` first. Parameters @@ -92,9 +92,9 @@ def leiden( :func:`~leidenalg.find_partition`. neighbors_key Use neighbors connectivities as adjacency. - If not specified, leiden looks .obsp['connectivities'] for connectivities + If not specified, leiden looks at .obsp['connectivities'] for connectivities (default storage place for pp.neighbors). - If specified, leiden looks + If specified, leiden looks at .obsp[.uns[neighbors_key]['connectivities_key']] for connectivities. obsp Use .obsp[obsp] as adjacency. You can't specify both diff --git a/src/scanpy/tools/_louvain.py b/src/scanpy/tools/_louvain.py index 50181229ab..1800dbe08e 100644 --- a/src/scanpy/tools/_louvain.py +++ b/src/scanpy/tools/_louvain.py @@ -69,10 +69,10 @@ def louvain( Cluster cells into subgroups :cite:p:`Blondel2008,Levine2015,Traag2017`. Cluster cells using the Louvain algorithm :cite:p:`Blondel2008` in the implementation - of :cite:t:`Traag2017`. The Louvain algorithm has been proposed for single-cell + of :cite:t:`Traag2017`. The Louvain algorithm was proposed for single-cell analysis by :cite:t:`Levine2015`. - This requires having ran :func:`~scanpy.pp.neighbors` or + This requires having run :func:`~scanpy.pp.neighbors` or :func:`~scanpy.external.pp.bbknn` first, or explicitly passing a ``adjacency`` matrix. diff --git a/src/scanpy/tools/_marker_gene_overlap.py b/src/scanpy/tools/_marker_gene_overlap.py index 1860fd73df..43408ff2c3 100644 --- a/src/scanpy/tools/_marker_gene_overlap.py +++ b/src/scanpy/tools/_marker_gene_overlap.py @@ -88,7 +88,7 @@ def marker_gene_overlap( inplace: bool = False, ): """\ - Calculate an overlap score between data-deriven marker genes and + Calculate an overlap score between data-derived marker genes and provided markers Marker gene overlap scores can be quoted as overlap counts, overlap diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index d0a33fdb97..7dff1a300c 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -79,8 +79,8 @@ def score_genes( """\ Score a set of genes :cite:p:`Satija2015`. 
-    The score is the average expression of a set of genes subtracted with the
-    average expression of a reference set of genes. The reference set is
+    The score is the average expression of a set of genes after subtraction by
+    the average expression of a reference set of genes. The reference set is
     randomly sampled from the `gene_pool` for each binned expression value.
 
     This reproduces the approach in Seurat :cite:p:`Satija2015` and has been implemented
diff --git a/src/scanpy/tools/_sim.py b/src/scanpy/tools/_sim.py
index a53575fa29..f6ea2fede8 100644
--- a/src/scanpy/tools/_sim.py
+++ b/src/scanpy/tools/_sim.py
@@ -62,7 +62,7 @@ def sim(
 
     Sample from a stochastic differential equation model built from
     literature-curated boolean gene regulatory networks, as suggested by
-    :cite:t:`Wittmann2009`. The Scanpy implementation is due to :cite:t:`Wolf2018`.
+    :cite:t:`Wittmann2009`. The Scanpy implementation can be found in :cite:t:`Wolf2018`.
 
     Parameters
     ----------
diff --git a/src/scanpy/tools/_top_genes.py b/src/scanpy/tools/_top_genes.py
index d66e9232f0..3b4e709b5f 100644
--- a/src/scanpy/tools/_top_genes.py
+++ b/src/scanpy/tools/_top_genes.py
@@ -38,7 +38,7 @@ def correlation_matrix(
     """\
     Calculate correlation matrix.
 
-    Calculate a correlation matrix for genes strored in sample annotation
+    Calculate a correlation matrix for genes stored in sample annotation
     using :func:`~scanpy.tl.rank_genes_groups`.
 
     Parameters
@@ -73,7 +73,7 @@ def correlation_matrix(
     spearman
         Spearman rank correlation
     annotation_key
-        Allows to define the name of the anndata entry where results are stored.
+        Allows defining the name of the anndata entry where results are stored.
     """
 
     # TODO: At the moment, only works for int identifiers
diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py
index 18e4a47f8e..62fa8b9d57 100644
--- a/src/scanpy/tools/_tsne.py
+++ b/src/scanpy/tools/_tsne.py
@@ -47,7 +47,7 @@ def tsne(
     """\
     t-SNE :cite:p:`vanDerMaaten2008,Amir2013,Pedregosa2011`.
 
-    t-distributed stochastic neighborhood embedding (tSNE, :cite:t:`vanDerMaaten2008`) has been
+    t-distributed stochastic neighborhood embedding (tSNE, :cite:t:`vanDerMaaten2008`) was
     proposed for visualizing single-cell data by :cite:t:`Amir2013`. Here, by default,
     we use the implementation of *scikit-learn* :cite:p:`Pedregosa2011`. You can
     achieve a huge speedup and better convergence if you install Multicore-tSNE_
 
From 05ab694ed1c8f7f12f274ec49a0991b59276eddb Mon Sep 17 00:00:00 2001
From: "Philipp A."
Date: Thu, 16 Jan 2025 11:57:12 +0100 Subject: [PATCH 20/24] (chore): Replace spmatrix with _CSMatrix as appropriate (#3431) --- docs/conf.py | 3 - src/scanpy/_utils/__init__.py | 47 +++++++++----- src/scanpy/_utils/compute/is_constant.py | 18 +++--- src/scanpy/external/tl/_harmony_timeseries.py | 4 +- src/scanpy/external/tl/_palantir.py | 2 +- src/scanpy/get/_aggregated.py | 5 +- src/scanpy/get/get.py | 13 ++-- src/scanpy/metrics/_common.py | 35 ++++++---- src/scanpy/metrics/_gearys_c.py | 9 ++- src/scanpy/metrics/_morans_i.py | 9 ++- src/scanpy/neighbors/_types.py | 6 +- .../_deprecated/highly_variable_genes.py | 4 +- .../preprocessing/_deprecated/sampling.py | 8 +-- .../preprocessing/_highly_variable_genes.py | 2 +- src/scanpy/preprocessing/_pca/__init__.py | 17 ++--- src/scanpy/preprocessing/_pca/_compat.py | 6 +- src/scanpy/preprocessing/_pca/_dask_sparse.py | 10 ++- src/scanpy/preprocessing/_qc.py | 37 ++++++----- src/scanpy/preprocessing/_scale.py | 18 +++--- src/scanpy/preprocessing/_scrublet/core.py | 17 ++--- .../preprocessing/_scrublet/sparse_utils.py | 15 +++-- src/scanpy/preprocessing/_simple.py | 64 ++++++++++++------- src/scanpy/preprocessing/_utils.py | 16 +++-- src/scanpy/tools/_leiden.py | 5 +- src/scanpy/tools/_louvain.py | 4 +- src/scanpy/tools/_rank_genes_groups.py | 6 +- src/scanpy/tools/_score_genes.py | 11 ++-- src/scanpy/tools/_utils.py | 46 +++++++------ src/scanpy/tools/_utils_clustering.py | 11 ++-- src/testing/scanpy/_pytest/fixtures/data.py | 6 +- tests/test_metrics.py | 22 ++++--- tests/test_qc_metrics.py | 11 ++-- tests/test_utils.py | 6 +- 33 files changed, 275 insertions(+), 218 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index e17aa9df0f..7306185eb2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -227,9 +227,6 @@ def setup(app: Sphinx): ("py:class", "scanpy._utils.Empty"), ("py:class", "numpy.random.mtrand.RandomState"), ("py:class", "scanpy.neighbors._types.KnnTransformerLike"), - # Will work once scipy 1.8 is released - ("py:class", "scipy.sparse.base.spmatrix"), - ("py:class", "scipy.sparse.csr.csr_matrix"), ] # Options for plot examples diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index 326ea216d1..4965163f1e 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -49,17 +49,23 @@ else: from anndata._core.sparse_dataset import SparseDataset +_CSMatrix = sparse.csr_matrix | sparse.csc_matrix + if TYPE_CHECKING: from collections.abc import Callable, Iterable, KeysView, Mapping from pathlib import Path from typing import Any, TypeVar from anndata import AnnData + from igraph import Graph from numpy.typing import ArrayLike, DTypeLike, NDArray from .._compat import _LegacyRandom from ..neighbors import NeighborsParams, RPForestDict + _MemoryArray = NDArray | _CSMatrix + _SupportedArray = _MemoryArray | DaskArray + SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence RNGLike = np.random.Generator | np.random.BitGenerator @@ -282,7 +288,7 @@ def _check_use_raw( # -------------------------------------------------------------------------------- -def get_igraph_from_adjacency(adjacency, directed=None): +def get_igraph_from_adjacency(adjacency: _CSMatrix, *, directed: bool = False) -> Graph: """Get igraph graph from adjacency matrix.""" import igraph as ig @@ -564,21 +570,16 @@ def get_literal_vals(typ: UnionType | Any) -> KeysView[Any]: # -------------------------------------------------------------------------------- -if TYPE_CHECKING: - _SparseMatrix = sparse.csr_matrix | 
sparse.csc_matrix
-    _MemoryArray = NDArray | _SparseMatrix
-    _SupportedArray = _MemoryArray | DaskArray
-
-
 @singledispatch
 def elem_mul(x: _SupportedArray, y: _SupportedArray) -> _SupportedArray:
     raise NotImplementedError


 @elem_mul.register(np.ndarray)
-@elem_mul.register(sparse.spmatrix)
+@elem_mul.register(sparse.csc_matrix)
+@elem_mul.register(sparse.csr_matrix)
 def _elem_mul_in_mem(x: _MemoryArray, y: _MemoryArray) -> _MemoryArray:
-    if isinstance(x, sparse.spmatrix):
+    if isinstance(x, _CSMatrix):
         # returns coo_matrix, so cast back to input type
         return type(x)(x.multiply(y))
     return x * y
@@ -630,14 +631,14 @@ def axis_mul_or_truediv(
 @axis_mul_or_truediv.register(sparse.csr_matrix)
 @axis_mul_or_truediv.register(sparse.csc_matrix)
 def _(
-    X: sparse.csr_matrix | sparse.csc_matrix,
+    X: _CSMatrix,
     scaling_array,
     axis: Literal[0, 1],
     op: Callable[[Any, Any], Any],
     *,
     allow_divide_by_zero: bool = True,
-    out: sparse.csr_matrix | sparse.csc_matrix | None = None,
-) -> sparse.csr_matrix | sparse.csc_matrix:
+    out: _CSMatrix | None = None,
+) -> _CSMatrix:
     check_op(op)
     if out is not None and X.data is not out.data:
         msg = "`out` argument provided but not equal to X. This behavior is not supported for sparse matrix scaling."
@@ -769,13 +770,22 @@ def axis_sum(
 ) -> np.matrix: ...


-@singledispatch
+@overload
 def axis_sum(
     X: np.ndarray,
     *,
     axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
     dtype: DTypeLike | None = None,
-) -> np.ndarray:
+) -> np.ndarray: ...
+
+
+@singledispatch
+def axis_sum(
+    X: np.ndarray | sparse.spmatrix,
+    *,
+    axis: tuple[Literal[0, 1], ...] | Literal[0, 1] | None = None,
+    dtype: DTypeLike | None = None,
+) -> np.ndarray | np.matrix:
     return np.sum(X, axis=axis, dtype=dtype)


@@ -830,7 +840,8 @@ def check_nonnegative_integers(X: _SupportedArray) -> bool | DaskArray:


 @check_nonnegative_integers.register(np.ndarray)
-@check_nonnegative_integers.register(sparse.spmatrix)
+@check_nonnegative_integers.register(sparse.csr_matrix)
+@check_nonnegative_integers.register(sparse.csc_matrix)
 def _check_nonnegative_integers_in_mem(X: _MemoryArray) -> bool:
     from numbers import Integral

@@ -1132,8 +1143,10 @@ def __contains__(self, key: str) -> bool:
         return key in self._neighbors_dict


-def _choose_graph(adata, obsp, neighbors_key):
-    """Choose connectivities from neighbbors or another obsp column"""
+def _choose_graph(
+    adata: AnnData, obsp: str | None, neighbors_key: str | None
+) -> _CSMatrix:
+    """Choose connectivities from neighbors or another obsp entry."""
     if obsp is not None and neighbors_key is not None:
         msg = "You can't specify both obsp, neighbors_key. Please select only one."
raise ValueError(msg) diff --git a/src/scanpy/_utils/compute/is_constant.py b/src/scanpy/_utils/compute/is_constant.py index c9fac4abf0..43da127fb7 100644 --- a/src/scanpy/_utils/compute/is_constant.py +++ b/src/scanpy/_utils/compute/is_constant.py @@ -1,9 +1,8 @@ from __future__ import annotations -from collections.abc import Callable from functools import partial, singledispatch, wraps from numbers import Integral -from typing import TYPE_CHECKING, TypeVar, overload +from typing import TYPE_CHECKING, overload import numba import numpy as np @@ -12,11 +11,16 @@ from ..._compat import DaskArray, njit if TYPE_CHECKING: - from typing import Literal + from collections.abc import Callable + from typing import Literal, TypeVar from numpy.typing import NDArray -C = TypeVar("C", bound=Callable) + from ..._utils import _CSMatrix + + _Array = NDArray | DaskArray | _CSMatrix + + C = TypeVar("C", bound=Callable) def _check_axis_supported(wrapped: C) -> C: @@ -35,11 +39,9 @@ def func(a, axis=None): @overload -def is_constant(a: NDArray, axis: None = None) -> bool: ... - - +def is_constant(a: _Array, axis: None = None) -> bool: ... @overload -def is_constant(a: NDArray, axis: Literal[0, 1]) -> NDArray[np.bool_]: ... +def is_constant(a: _Array, axis: Literal[0, 1]) -> NDArray[np.bool_]: ... @_check_axis_supported diff --git a/src/scanpy/external/tl/_harmony_timeseries.py b/src/scanpy/external/tl/_harmony_timeseries.py index de3f8cde26..5cbfd2d856 100644 --- a/src/scanpy/external/tl/_harmony_timeseries.py +++ b/src/scanpy/external/tl/_harmony_timeseries.py @@ -76,9 +76,9 @@ def harmony_timeseries( **X_harmony** - :class:`~numpy.ndarray` (:attr:`~anndata.AnnData.obsm`, dtype `float`) force directed layout - **harmony_aff** - :class:`~scipy.sparse.spmatrix` (:attr:`~anndata.AnnData.obsp`, dtype `float`) + **harmony_aff** - :class:`~scipy.sparse.csr_matrix` (:attr:`~anndata.AnnData.obsp`, dtype `float`) affinity matrix - **harmony_aff_aug** - :class:`~scipy.sparse.spmatrix` (:attr:`~anndata.AnnData.obsp`, dtype `float`) + **harmony_aff_aug** - :class:`~scipy.sparse.csr_matrix` (:attr:`~anndata.AnnData.obsp`, dtype `float`) augmented affinity matrix **harmony_timepoint_var** - `str` (:attr:`~anndata.AnnData.uns`) The name of the variable passed as `tp` diff --git a/src/scanpy/external/tl/_palantir.py b/src/scanpy/external/tl/_palantir.py index eb060bbbe0..bb232d9d5f 100644 --- a/src/scanpy/external/tl/_palantir.py +++ b/src/scanpy/external/tl/_palantir.py @@ -94,7 +94,7 @@ def palantir( Array of Diffusion components. - palantir_EigenValues - :class:`~numpy.ndarray` (:attr:`~anndata.AnnData.uns`, dtype `float`) Array of corresponding eigen values. - - palantir_diff_op - :class:`~scipy.sparse.spmatrix` (:attr:`~anndata.AnnData.obsp`, dtype `float`) + - palantir_diff_op - :class:`~scipy.sparse.csr_matrix` (:attr:`~anndata.AnnData.obsp`, dtype `float`) The diffusion operator matrix. 
**Multi scale space results**, diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 94bf202b69..aa7f93d9ce 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -338,9 +338,10 @@ def aggregate_df(data, by, func, *, mask=None, dof=1): @_aggregate.register(np.ndarray) -@_aggregate.register(sparse.spmatrix) +@_aggregate.register(sparse.csr_matrix) +@_aggregate.register(sparse.csc_matrix) def aggregate_array( - data, + data: Array, by: pd.Categorical, func: AggType | Iterable[AggType], *, diff --git a/src/scanpy/get/get.py b/src/scanpy/get/get.py index abfa51d1f9..aea8be9eb2 100644 --- a/src/scanpy/get/get.py +++ b/src/scanpy/get/get.py @@ -9,7 +9,8 @@ from anndata import AnnData from numpy.typing import NDArray from packaging.version import Version -from scipy.sparse import spmatrix + +from .._utils import _CSMatrix if TYPE_CHECKING: from collections.abc import Collection, Iterable @@ -17,11 +18,9 @@ from anndata._core.sparse_dataset import BaseCompressedSparseDataset from anndata._core.views import ArrayView - from scipy.sparse import csc_matrix, csr_matrix from .._compat import DaskArray - CSMatrix = csr_matrix | csc_matrix # -------------------------------------------------------------------------------- # Plotting data helpers @@ -334,7 +333,7 @@ def obs_df( val = adata.obsm[k] if isinstance(val, np.ndarray): df[added_k] = np.ravel(val[:, idx]) - elif isinstance(val, spmatrix): + elif isinstance(val, _CSMatrix): df[added_k] = np.ravel(val[:, idx].toarray()) elif isinstance(val, pd.DataFrame): df[added_k] = val.loc[:, idx] @@ -404,7 +403,7 @@ def var_df( val = adata.varm[k] if isinstance(val, np.ndarray): df[added_k] = np.ravel(val[:, idx]) - elif isinstance(val, spmatrix): + elif isinstance(val, _CSMatrix): df[added_k] = np.ravel(val[:, idx].toarray()) elif isinstance(val, pd.DataFrame): df[added_k] = val.loc[:, idx] @@ -420,7 +419,7 @@ def _get_obs_rep( obsp: str | None = None, ) -> ( np.ndarray - | spmatrix + | _CSMatrix | pd.DataFrame | ArrayView | BaseCompressedSparseDataset @@ -497,7 +496,7 @@ def _set_obs_rep( def _check_mask( - data: AnnData | np.ndarray | CSMatrix | DaskArray, + data: AnnData | np.ndarray | _CSMatrix | DaskArray, mask: str | M, dim: Literal["obs", "var"], *, diff --git a/src/scanpy/metrics/_common.py b/src/scanpy/metrics/_common.py index 5f2d37183b..804fd8eab2 100644 --- a/src/scanpy/metrics/_common.py +++ b/src/scanpy/metrics/_common.py @@ -2,7 +2,7 @@ import warnings from functools import singledispatch -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, overload import numpy as np import pandas as pd @@ -11,38 +11,49 @@ from .._compat import DaskArray if TYPE_CHECKING: + from typing import NoReturn, TypeVar + from numpy.typing import NDArray + T_NonSparse = TypeVar("T_NonSparse", bound=NDArray | DaskArray) + V = TypeVar("V", bound=NDArray | sparse.csr_matrix) + + +@overload +def _resolve_vals(val: T_NonSparse) -> T_NonSparse: ... +@overload +def _resolve_vals(val: sparse.spmatrix) -> sparse.csr_matrix: ... +@overload +def _resolve_vals(val: pd.DataFrame | pd.Series) -> NDArray: ... 
+ @singledispatch -def _resolve_vals(val: NDArray | sparse.spmatrix) -> NDArray | sparse.csr_matrix: - return np.asarray(val) +def _resolve_vals(val: object) -> NoReturn: + msg = f"Unsupported type {type(val)}" + raise TypeError(msg) @_resolve_vals.register(np.ndarray) @_resolve_vals.register(sparse.csr_matrix) @_resolve_vals.register(DaskArray) -def _(val): +def _( + val: np.ndarray | sparse.csr_matrix | DaskArray, +) -> np.ndarray | sparse.csr_matrix | DaskArray: return val @_resolve_vals.register(sparse.spmatrix) -def _(val): +def _(val: sparse.spmatrix) -> sparse.csr_matrix: return sparse.csr_matrix(val) @_resolve_vals.register(pd.DataFrame) @_resolve_vals.register(pd.Series) -def _(val): +def _(val: pd.DataFrame | pd.Series) -> NDArray: return val.to_numpy() -V = TypeVar("V", np.ndarray, sparse.csr_matrix) - - -def _check_vals( - vals: V, -) -> tuple[V, NDArray[np.bool_] | slice, NDArray[np.float64]]: +def _check_vals(vals: V) -> tuple[V, NDArray[np.bool_] | slice, NDArray[np.float64]]: """\ Checks that values wont cause issues in computation. diff --git a/src/scanpy/metrics/_gearys_c.py b/src/scanpy/metrics/_gearys_c.py index cf4220eb7a..33b77b4c63 100644 --- a/src/scanpy/metrics/_gearys_c.py +++ b/src/scanpy/metrics/_gearys_c.py @@ -15,13 +15,16 @@ if TYPE_CHECKING: from anndata import AnnData + from numpy.typing import NDArray + + from .._compat import DaskArray @singledispatch def gearys_c( adata: AnnData, *, - vals: np.ndarray | sparse.spmatrix | None = None, + vals: NDArray | sparse.spmatrix | DaskArray | None = None, use_graph: str | None = None, layer: str | None = None, obsm: str | None = None, @@ -290,7 +293,9 @@ def _gearys_c_mtx_csr( # noqa: PLR0917 @gearys_c.register(sparse.csr_matrix) -def _gearys_c(g: sparse.csr_matrix, vals: np.ndarray | sparse.spmatrix) -> np.ndarray: +def _gearys_c( + g: sparse.csr_matrix, vals: NDArray | sparse.spmatrix | DaskArray +) -> np.ndarray: assert g.shape[0] == g.shape[1], "`g` should be a square adjacency matrix" vals = _resolve_vals(vals) g_data = g.data.astype(np.float64, copy=False) diff --git a/src/scanpy/metrics/_morans_i.py b/src/scanpy/metrics/_morans_i.py index c21c455f38..516b1ce616 100644 --- a/src/scanpy/metrics/_morans_i.py +++ b/src/scanpy/metrics/_morans_i.py @@ -15,13 +15,16 @@ if TYPE_CHECKING: from anndata import AnnData + from numpy.typing import NDArray + + from .._compat import DaskArray @singledispatch def morans_i( adata: AnnData, *, - vals: np.ndarray | sparse.spmatrix | None = None, + vals: NDArray | sparse.spmatrix | DaskArray | None = None, use_graph: str | None = None, layer: str | None = None, obsm: str | None = None, @@ -226,7 +229,9 @@ def _morans_i_mtx_csr( # noqa: PLR0917 @morans_i.register(sparse.csr_matrix) -def _morans_i(g: sparse.csr_matrix, vals: np.ndarray | sparse.spmatrix) -> np.ndarray: +def _morans_i( + g: sparse.csr_matrix, vals: NDArray | sparse.spmatrix | DaskArray +) -> np.ndarray: assert g.shape[0] == g.shape[1], "`g` should be a square adjacency matrix" vals = _resolve_vals(vals) g_data = g.data.astype(np.float64, copy=False) diff --git a/src/scanpy/neighbors/_types.py b/src/scanpy/neighbors/_types.py index 39f50284ec..397f6f38c6 100644 --- a/src/scanpy/neighbors/_types.py +++ b/src/scanpy/neighbors/_types.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from typing import Any, Self - from scipy.sparse import spmatrix + from scipy.sparse import csr_matrix # These two are used with get_literal_vals elsewhere @@ -49,10 +49,10 @@ class KnnTransformerLike(Protocol): """See 
:class:`~sklearn.neighbors.KNeighborsTransformer`."""

     def fit(self, X, y: None = None): ...
-    def transform(self, X) -> spmatrix: ...
+    def transform(self, X) -> csr_matrix: ...

     # from TransformerMixin
-    def fit_transform(self, X, y: None = None) -> spmatrix: ...
+    def fit_transform(self, X, y: None = None) -> csr_matrix: ...

     # from BaseEstimator
     def get_params(self, *, deep: bool = True) -> dict[str, Any]: ...
diff --git a/src/scanpy/preprocessing/_deprecated/highly_variable_genes.py b/src/scanpy/preprocessing/_deprecated/highly_variable_genes.py
index bba4fb9bbf..f841e24da3 100644
--- a/src/scanpy/preprocessing/_deprecated/highly_variable_genes.py
+++ b/src/scanpy/preprocessing/_deprecated/highly_variable_genes.py
@@ -16,7 +16,7 @@
 if TYPE_CHECKING:
     from typing import Literal

-    from scipy.sparse import spmatrix
+    from ..._utils import _CSMatrix


 @deprecated("Use sc.pp.highly_variable_genes instead")
@@ -33,7 +33,7 @@
     "copy",
 )
 def filter_genes_dispersion(
-    data: AnnData | spmatrix | np.ndarray,
+    data: AnnData | _CSMatrix | np.ndarray,
     *,
     flavor: Literal["seurat", "cell_ranger"] = "seurat",
     min_disp: float | None = None,
diff --git a/src/scanpy/preprocessing/_deprecated/sampling.py b/src/scanpy/preprocessing/_deprecated/sampling.py
index 02619a2364..4be071fc02 100644
--- a/src/scanpy/preprocessing/_deprecated/sampling.py
+++ b/src/scanpy/preprocessing/_deprecated/sampling.py
@@ -9,22 +9,20 @@
     import numpy as np
     from anndata import AnnData
     from numpy.typing import NDArray
-    from scipy.sparse import csc_matrix, csr_matrix

     from ..._compat import _LegacyRandom
-
-    CSMatrix = csr_matrix | csc_matrix
+    from ..._utils import _CSMatrix


 @old_positionals("n_obs", "random_state", "copy")
 def subsample(
-    data: AnnData | np.ndarray | CSMatrix,
+    data: AnnData | np.ndarray | _CSMatrix,
     fraction: float | None = None,
     *,
     n_obs: int | None = None,
     random_state: _LegacyRandom = 0,
     copy: bool = False,
-) -> AnnData | tuple[np.ndarray | CSMatrix, NDArray[np.int64]] | None:
+) -> AnnData | tuple[np.ndarray | _CSMatrix, NDArray[np.int64]] | None:
     """\
     Subsample to a fraction of the number of observations.
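Editor's note: patch 20 threads the `_CSMatrix` alias through the code base as a plain runtime union (`sparse.csr_matrix | sparse.csc_matrix`, defined in `src/scanpy/_utils/__init__.py` above). Unlike the old `spmatrix` checks, it rejects COO matrices, which for example are not subscriptable. A minimal sketch of the difference (illustration only, not part of the patch):

import numpy as np
from scipy import sparse

# The alias as defined in scanpy._utils (patch 20); isinstance with a
# runtime union works on Python >= 3.10.
_CSMatrix = sparse.csr_matrix | sparse.csc_matrix

x_csr = sparse.csr_matrix(np.eye(3))
x_coo = sparse.coo_matrix(np.eye(3))

assert isinstance(x_csr, _CSMatrix)
assert isinstance(x_coo, sparse.spmatrix)  # the old, too-broad check matches COO
assert not isinstance(x_coo, _CSMatrix)    # the new check rejects it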
diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index 356fa8f03f..c5dc3a27a2 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -103,7 +103,7 @@ def _highly_variable_genes_seurat_v3( vmax = np.sqrt(N) clip_val = reg_std * vmax + mean if sp_sparse.issparse(data_batch): - if sp_sparse.isspmatrix_csr(data_batch): + if isinstance(data_batch, sp_sparse.csr_matrix): batch_counts = data_batch else: batch_counts = sp_sparse.csr_matrix(data_batch) diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index db7886a29f..77c5c6b3fe 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -27,13 +27,9 @@ import dask_ml.decomposition as dmld import sklearn.decomposition as skld from numpy.typing import DTypeLike, NDArray - from scipy import sparse - from scipy.sparse import spmatrix from ..._compat import _LegacyRandom - from ..._utils import Empty - - CSMatrix = sparse.csr_matrix | sparse.csc_matrix + from ..._utils import Empty, _CSMatrix MethodDaskML = type[dmld.PCA | dmld.IncrementalPCA | dmld.TruncatedSVD] MethodSklearn = type[skld.PCA | skld.TruncatedSVD] @@ -65,7 +61,7 @@ mask_var_hvg=doc_mask_var_hvg, ) def pca( - data: AnnData | np.ndarray | spmatrix, + data: AnnData | np.ndarray | _CSMatrix, n_comps: int | None = None, *, layer: str | None = None, @@ -80,7 +76,7 @@ def pca( chunk_size: int | None = None, key_added: str | None = None, copy: bool = False, -) -> AnnData | np.ndarray | spmatrix | None: +) -> AnnData | np.ndarray | _CSMatrix | None: """\ Principal component analysis :cite:p:`Pedregosa2011`. @@ -195,7 +191,7 @@ def pca( Otherwise, it returns `None` if `copy=False`, else an updated `AnnData` object. Sets the following fields: - `.obsm['X_pca' | key_added]` : :class:`~scipy.sparse.spmatrix` | :class:`~numpy.ndarray` (shape `(adata.n_obs, n_comps)`) + `.obsm['X_pca' | key_added]` : :class:`~scipy.sparse.csr_matrix` | :class:`~scipy.sparse.csc_matrix` | :class:`~numpy.ndarray` (shape `(adata.n_obs, n_comps)`) PCA representation of data. `.varm['PCs' | key_added]` : :class:`~numpy.ndarray` (shape `(adata.n_vars, n_comps)`) The principal components containing the loadings. @@ -218,8 +214,7 @@ def pca( "reproducible across different computational platforms. For exact " "reproducibility, choose `svd_solver='arpack'`." 
        )
-    data_is_AnnData = isinstance(data, AnnData)
-    if data_is_AnnData:
+    if return_anndata := isinstance(data, AnnData):
         if layer is None and not chunked and is_backed_type(data.X):
             msg = f"PCA is not implemented for matrices of type {type(data.X)} with chunked as False"
             raise NotImplementedError(msg)
@@ -377,7 +372,7 @@ def pca(
     if X_pca.dtype.descr != np.dtype(dtype).descr:
         X_pca = X_pca.astype(dtype)

-    if data_is_AnnData:
+    if return_anndata:
         key_obsm, key_varm, key_uns = (
             ("X_pca", "PCs", "pca") if key_added is None else [key_added] * 3
         )
diff --git a/src/scanpy/preprocessing/_pca/_compat.py b/src/scanpy/preprocessing/_pca/_compat.py
index 28eef2ba1a..056650fd15 100644
--- a/src/scanpy/preprocessing/_pca/_compat.py
+++ b/src/scanpy/preprocessing/_pca/_compat.py
@@ -15,16 +15,14 @@
     from typing import Literal

     from numpy.typing import NDArray
-    from scipy import sparse
     from sklearn.decomposition import PCA

     from ..._compat import _LegacyRandom
-
-    CSMatrix = sparse.csr_matrix | sparse.csc_matrix
+    from ..._utils import _CSMatrix


 def _pca_compat_sparse(
-    x: CSMatrix,
+    x: _CSMatrix,
     n_pcs: int,
     *,
     solver: Literal["arpack", "lobpcg"],
diff --git a/src/scanpy/preprocessing/_pca/_dask_sparse.py b/src/scanpy/preprocessing/_pca/_dask_sparse.py
index c2bff7ccca..7f53bda992 100644
--- a/src/scanpy/preprocessing/_pca/_dask_sparse.py
+++ b/src/scanpy/preprocessing/_pca/_dask_sparse.py
@@ -15,11 +15,9 @@
     from typing import Literal

     from numpy.typing import DTypeLike
-    from scipy import sparse

     from ..._compat import DaskArray
-
-    CSMatrix = sparse.csr_matrix | sparse.csc_matrix
+    from ..._utils import _CSMatrix


 @dataclass
@@ -120,7 +118,7 @@ def transform(self, x: DaskArray) -> DaskArray:
         import dask.array as da

         def transform_block(
-            x_part: CSMatrix,
+            x_part: _CSMatrix,
             mean_: NDArray[np.floating],
             components_: NDArray[np.floating],
         ):
@@ -191,8 +189,8 @@ def _cov_sparse_dask(
     else:
         dtype = np.dtype(dtype)

-    def gram_block(x_part: CSMatrix):
-        gram_matrix: CSMatrix = x_part.T @ x_part
+    def gram_block(x_part: _CSMatrix):
+        gram_matrix: _CSMatrix = x_part.T @ x_part
         return gram_matrix.toarray()[None, ...]
# need new axis for summing gram_matrix_dask: DaskArray = da.map_blocks( diff --git a/src/scanpy/preprocessing/_qc.py b/src/scanpy/preprocessing/_qc.py index 5af8def042..e57719a20d 100644 --- a/src/scanpy/preprocessing/_qc.py +++ b/src/scanpy/preprocessing/_qc.py @@ -1,19 +1,19 @@ from __future__ import annotations -from functools import singledispatch +from functools import singledispatch, wraps from typing import TYPE_CHECKING from warnings import warn import numba import numpy as np import pandas as pd -from scipy.sparse import csr_matrix, issparse, isspmatrix_coo, isspmatrix_csr, spmatrix +from scipy.sparse import coo_matrix, csc_matrix, csr_matrix, issparse from scanpy.preprocessing._distributed import materialize_as_ndarray from scanpy.preprocessing._utils import _get_mean_var from .._compat import DaskArray, njit -from .._utils import _doc_params, axis_nnz, axis_sum +from .._utils import _CSMatrix, _doc_params, axis_nnz, axis_sum from ._docs import ( doc_adata_basic, doc_expr_reps, @@ -103,9 +103,9 @@ def describe_obs( # Handle whether X is passed if X is None: X = _choose_mtx_rep(adata, use_raw=use_raw, layer=layer) - if isspmatrix_coo(X): + if isinstance(X, coo_matrix): X = csr_matrix(X) # COO not subscriptable - if issparse(X): + if isinstance(X, _CSMatrix): X.eliminate_zeros() obs_metrics = pd.DataFrame(index=adata.obs_names) obs_metrics[f"n_{var_type}_by_{expr_type}"] = materialize_as_ndarray( @@ -162,7 +162,7 @@ def describe_var( use_raw: bool = False, inplace: bool = False, log1p: bool = True, - X: spmatrix | np.ndarray | None = None, + X: _CSMatrix | coo_matrix | np.ndarray | None = None, ) -> pd.DataFrame | None: """\ Describe variables of anndata. @@ -190,9 +190,9 @@ def describe_var( # Handle whether X is passed if X is None: X = _choose_mtx_rep(adata, use_raw=use_raw, layer=layer) - if isspmatrix_coo(X): + if isinstance(X, coo_matrix): X = csr_matrix(X) # COO not subscriptable - if issparse(X): + if isinstance(X, _CSMatrix): X.eliminate_zeros() var_metrics = pd.DataFrame(index=adata.var_names) var_metrics[f"n_cells_by_{expr_type}"], var_metrics[f"mean_{expr_type}"] = ( @@ -299,9 +299,9 @@ def calculate_qc_metrics( ) # Pass X so I only have to do it once X = _choose_mtx_rep(adata, use_raw=use_raw, layer=layer) - if isspmatrix_coo(X): + if isinstance(X, coo_matrix): X = csr_matrix(X) # COO not subscriptable - if issparse(X): + if isinstance(X, _CSMatrix): X.eliminate_zeros() # Convert qc_vars to list if str @@ -331,7 +331,7 @@ def calculate_qc_metrics( return obs_metrics, var_metrics -def top_proportions(mtx: np.ndarray | spmatrix, n: int): +def top_proportions(mtx: np.ndarray | _CSMatrix | coo_matrix, n: int): """\ Calculates cumulative proportions of top expressed genes @@ -345,7 +345,7 @@ def top_proportions(mtx: np.ndarray | spmatrix, n: int): expressed gene. """ if issparse(mtx): - if not isspmatrix_csr(mtx): + if not isinstance(mtx, csr_matrix): mtx = csr_matrix(mtx) # Allowing numba to do more return top_proportions_sparse_csr(mtx.data, mtx.indptr, np.array(n)) @@ -383,7 +383,10 @@ def top_proportions_sparse_csr(data, indptr, n): def check_ns(func): - def check_ns_inner(mtx: np.ndarray | spmatrix | DaskArray, ns: Collection[int]): + @wraps(func) + def check_ns_inner( + mtx: np.ndarray | _CSMatrix | coo_matrix | DaskArray, ns: Collection[int] + ): if not (max(ns) <= mtx.shape[1] and min(ns) > 0): msg = "Positions outside range of features." 
raise IndexError(msg) @@ -434,10 +437,12 @@ def _(mtx: DaskArray, ns: Collection[int]) -> DaskArray: ).compute() -@top_segment_proportions.register(spmatrix) +@top_segment_proportions.register(csr_matrix) +@top_segment_proportions.register(csc_matrix) +@top_segment_proportions.register(coo_matrix) @check_ns -def _(mtx: spmatrix, ns: Collection[int]) -> DaskArray: - if not isspmatrix_csr(mtx): +def _(mtx: _CSMatrix | coo_matrix, ns: Collection[int]) -> DaskArray: + if not isinstance(mtx, csr_matrix): mtx = csr_matrix(mtx) return top_segment_proportions_sparse_csr(mtx.data, mtx.indptr, np.array(ns)) diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py index ee15f977b9..2a8ff2140e 100644 --- a/src/scanpy/preprocessing/_scale.py +++ b/src/scanpy/preprocessing/_scale.py @@ -8,7 +8,7 @@ import numba import numpy as np from anndata import AnnData -from scipy.sparse import issparse, isspmatrix_csc, spmatrix +from scipy.sparse import csc_matrix, csr_matrix, issparse from .. import logging as logg from .._compat import DaskArray, njit, old_positionals @@ -30,9 +30,8 @@ if TYPE_CHECKING: from numpy.typing import NDArray - from scipy import sparse as sp - CSMatrix = sp.csr_matrix | sp.csc_matrix + from .._utils import _CSMatrix @njit @@ -66,7 +65,7 @@ def clip_array( return X -def clip_set(x: CSMatrix, *, max_value: float, zero_center: bool = True) -> CSMatrix: +def clip_set(x: _CSMatrix, *, max_value: float, zero_center: bool = True) -> _CSMatrix: x = x.copy() x[x > max_value] = max_value if zero_center: @@ -78,7 +77,7 @@ def clip_set(x: CSMatrix, *, max_value: float, zero_center: bool = True) -> CSMa @old_positionals("zero_center", "max_value", "copy", "layer", "obsm") @singledispatch def scale( - data: AnnData | spmatrix | np.ndarray | DaskArray, + data: AnnData | _CSMatrix | np.ndarray | DaskArray, *, zero_center: bool = True, max_value: float | None = None, @@ -86,7 +85,7 @@ def scale( layer: str | None = None, obsm: str | None = None, mask_obs: NDArray[np.bool_] | str | None = None, -) -> AnnData | spmatrix | np.ndarray | DaskArray | None: +) -> AnnData | _CSMatrix | np.ndarray | DaskArray | None: """\ Scale data to unit variance and zero mean. 
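Editor's note: the hunks above (`top_segment_proportions`) and below (`scale`) stack one `.register(...)` decorator per concrete sparse class because `functools.singledispatch` only accepts union types from Python 3.11 on, and scanpy still supports 3.10 at this point. A self-contained sketch of the pattern, using a hypothetical `nnz_per_row` helper that is not part of scanpy:

from functools import singledispatch

import numpy as np
from scipy import sparse


@singledispatch
def nnz_per_row(x) -> np.ndarray:
    msg = f"Unsupported type {type(x)}"
    raise NotImplementedError(msg)


@nnz_per_row.register(np.ndarray)
def _(x: np.ndarray) -> np.ndarray:
    return np.count_nonzero(x, axis=1)


@nnz_per_row.register(sparse.csr_matrix)
@nnz_per_row.register(sparse.csc_matrix)  # one decorator per class, as in the patch
def _(x) -> np.ndarray:
    # stored entries per row, read off the CSR index pointer
    return np.diff(sparse.csr_matrix(x).indptr)


print(nnz_per_row(sparse.csc_matrix(np.eye(3))))  # -> [1 1 1]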
@@ -228,9 +227,10 @@ def scale_array(
     return X


-@scale.register(spmatrix)
+@scale.register(csr_matrix)
+@scale.register(csc_matrix)
 def scale_sparse(
-    X: spmatrix,
+    X: _CSMatrix,
     *,
     zero_center: bool = True,
     max_value: float | None = None,
@@ -264,7 +264,7 @@ def scale_sparse(
             mask_obs=mask_obs,
         )
     else:
-        if isspmatrix_csc(X):
+        if isinstance(X, csc_matrix):
             X = X.tocsr()
         elif copy:
             X = X.copy()
diff --git a/src/scanpy/preprocessing/_scrublet/core.py b/src/scanpy/preprocessing/_scrublet/core.py
index 1236f42a7a..8130f6bd3c 100644
--- a/src/scanpy/preprocessing/_scrublet/core.py
+++ b/src/scanpy/preprocessing/_scrublet/core.py
@@ -23,6 +23,7 @@

     from ..._compat import _LegacyRandom
     from ...neighbors import _Metric, _MetricFn
+    from ..._utils import _CSMatrix

 __all__ = ["Scrublet"]

@@ -65,9 +66,7 @@ class Scrublet:

     # init fields

-    counts_obs: InitVar[sparse.csr_matrix | sparse.csc_matrix | NDArray[np.integer]] = (
-        field(kw_only=False)
-    )
+    counts_obs: InitVar[_CSMatrix | NDArray[np.integer]] = field(kw_only=False)
     total_counts_obs: InitVar[NDArray[np.integer] | None] = None
     sim_doublet_ratio: float = 2.0
     n_neighbors: InitVar[int | None] = None
@@ -82,15 +81,11 @@ class Scrublet:

     _counts_obs: sparse.csc_matrix = field(init=False, repr=False)
     _total_counts_obs: NDArray[np.integer] = field(init=False, repr=False)
-    _counts_obs_norm: sparse.csr_matrix | sparse.csc_matrix = field(
-        init=False, repr=False
-    )
+    _counts_obs_norm: _CSMatrix = field(init=False, repr=False)

-    _counts_sim: sparse.csr_matrix | sparse.csc_matrix = field(init=False, repr=False)
+    _counts_sim: _CSMatrix = field(init=False, repr=False)
     _total_counts_sim: NDArray[np.integer] = field(init=False, repr=False)
-    _counts_sim_norm: sparse.csr_matrix | sparse.csc_matrix | None = field(
-        default=None, init=False, repr=False
-    )
+    _counts_sim_norm: _CSMatrix | None = field(default=None, init=False, repr=False)

     # Fields set by methods

@@ -171,7 +166,7 @@ class Scrublet:

     def __post_init__(
         self,
-        counts_obs: sparse.csr_matrix | sparse.csc_matrix | NDArray[np.integer],
+        counts_obs: _CSMatrix | NDArray[np.integer],
         total_counts_obs: NDArray[np.integer] | None,
         n_neighbors: int | None,
         random_state: _LegacyRandom,
diff --git a/src/scanpy/preprocessing/_scrublet/sparse_utils.py b/src/scanpy/preprocessing/_scrublet/sparse_utils.py
index 795559583c..c570612c0a 100644
--- a/src/scanpy/preprocessing/_scrublet/sparse_utils.py
+++ b/src/scanpy/preprocessing/_scrublet/sparse_utils.py
@@ -12,13 +12,14 @@
 if TYPE_CHECKING:
     from numpy.typing import NDArray

-    from .._compat import _LegacyRandom
+    from ..._compat import _LegacyRandom
+    from ..._utils import _CSMatrix


 def sparse_multiply(
-    E: sparse.csr_matrix | sparse.csc_matrix | NDArray[np.float64],
+    E: _CSMatrix | NDArray[np.float64],
     a: float | NDArray[np.float64],
-) -> sparse.csr_matrix | sparse.csc_matrix:
+) -> _CSMatrix:
     """multiply each row of E by a scalar"""

     nrow = E.shape[0]
@@ -30,11 +31,11 @@ def sparse_multiply(


 def sparse_zscore(
-    E: sparse.csr_matrix | sparse.csc_matrix,
+    E: _CSMatrix,
     *,
     gene_mean: NDArray[np.float64] | None = None,
     gene_stdev: NDArray[np.float64] | None = None,
-) -> sparse.csr_matrix | sparse.csc_matrix:
+) -> _CSMatrix:
     """z-score normalize each column of E"""
     if gene_mean is None or gene_stdev is None:
         gene_means, gene_stdevs = _get_mean_var(E, axis=0)
@@ -43,12 +44,12 @@ def sparse_zscore(


 def subsample_counts(
-    E: sparse.csr_matrix | sparse.csc_matrix,
+    E: _CSMatrix,
     *,
     rate: float,
     original_totals,
     random_seed: _LegacyRandom = 0,
-) ->
tuple[sparse.csr_matrix | sparse.csc_matrix, NDArray[np.int64]]: +) -> tuple[_CSMatrix, NDArray[np.int64]]: if rate < 1: random_seed = _get_legacy_random(random_seed) E.data = random_seed.binomial(np.round(E.data).astype(int), rate) diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index fda79b4da2..986a2cd386 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -14,7 +14,7 @@ import numpy as np from anndata import AnnData from pandas.api.types import CategoricalDtype -from scipy.sparse import csc_matrix, csr_matrix, issparse, isspmatrix_csr, spmatrix +from scipy.sparse import csc_matrix, csr_matrix, issparse from sklearn.utils import check_array, sparsefuncs from .. import logging as logg @@ -22,6 +22,7 @@ from .._settings import settings as sett from .._utils import ( _check_array_function_arguments, + _CSMatrix, _resolve_axis, axis_sum, is_backed_type, @@ -51,16 +52,14 @@ from .._utils import RNGLike, SeedLike -CSMatrix = csr_matrix | csc_matrix - -A = TypeVar("A", bound=np.ndarray | CSMatrix | DaskArray) +A = TypeVar("A", bound=np.ndarray | _CSMatrix | DaskArray) @old_positionals( "min_counts", "min_genes", "max_counts", "max_genes", "inplace", "copy" ) def filter_cells( - data: AnnData | spmatrix | np.ndarray | DaskArray, + data: AnnData | _CSMatrix | np.ndarray | DaskArray, *, min_counts: int | None = None, min_genes: int | None = None, @@ -209,7 +208,7 @@ def filter_cells( "min_counts", "min_cells", "max_counts", "max_cells", "inplace", "copy" ) def filter_genes( - data: AnnData | spmatrix | np.ndarray | DaskArray, + data: AnnData | _CSMatrix | np.ndarray | DaskArray, *, min_counts: int | None = None, min_cells: int | None = None, @@ -322,7 +321,7 @@ def filter_genes( @renamed_arg("X", "data", pos_0=True) @singledispatch def log1p( - data: AnnData | np.ndarray | spmatrix, + data: AnnData | np.ndarray | _CSMatrix, *, base: Number | None = None, copy: bool = False, @@ -330,7 +329,7 @@ def log1p( chunk_size: int | None = None, layer: str | None = None, obsm: str | None = None, -) -> AnnData | np.ndarray | spmatrix | None: +) -> AnnData | np.ndarray | _CSMatrix | None: """\ Logarithmize the data matrix. @@ -367,8 +366,9 @@ def log1p( return log1p_array(data, copy=copy, base=base) -@log1p.register(spmatrix) -def log1p_sparse(X: spmatrix, *, base: Number | None = None, copy: bool = False): +@log1p.register(csr_matrix) +@log1p.register(csc_matrix) +def log1p_sparse(X: _CSMatrix, *, base: Number | None = None, copy: bool = False): X = check_array( X, accept_sparse=("csr", "csc"), dtype=(np.float64, np.float32), copy=copy ) @@ -437,12 +437,12 @@ def log1p_anndata( @old_positionals("copy", "chunked", "chunk_size") def sqrt( - data: AnnData | spmatrix | np.ndarray, + data: AnnData | _CSMatrix | np.ndarray, *, copy: bool = False, chunked: bool = False, chunk_size: int | None = None, -) -> AnnData | spmatrix | np.ndarray | None: +) -> AnnData | _CSMatrix | np.ndarray | None: """\ Square root the data matrix. 
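Editor's note: `subsample_counts` above keeps each observed count with probability `rate` by drawing binomials directly on the CSR `.data` buffer; the same idea underlies `downsample_counts` in `_simple.py` further down. A minimal sketch of that thinning step on a made-up matrix (illustration only, not the library code):

import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
E = sparse.csr_matrix(rng.poisson(2.0, size=(4, 5)))

rate = 0.5
# Binomial thinning: each stored count survives independently with probability `rate`.
E.data = rng.binomial(E.data.astype(int), rate)
E.eliminate_zeros()  # counts thinned to zero should no longer be stored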
@@ -492,7 +492,7 @@ def sqrt( "min_counts", ) def normalize_per_cell( - data: AnnData | np.ndarray | spmatrix, + data: AnnData | np.ndarray | _CSMatrix, *, counts_per_cell_after: float | None = None, counts_per_cell: np.ndarray | None = None, @@ -501,7 +501,7 @@ def normalize_per_cell( layers: Literal["all"] | Iterable[str] = (), use_rep: Literal["after", "X"] | None = None, min_counts: int = 1, -) -> AnnData | np.ndarray | spmatrix | None: +) -> AnnData | np.ndarray | _CSMatrix | None: """\ Normalize total counts per cell. @@ -872,7 +872,7 @@ def sample( p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None, ) -> tuple[A, NDArray[np.int64]]: ... def sample( - data: AnnData | np.ndarray | CSMatrix | DaskArray, + data: AnnData | np.ndarray | _CSMatrix | DaskArray, fraction: float | None = None, *, n: int | None = None, @@ -881,7 +881,7 @@ def sample( replace: bool = False, axis: Literal["obs", 0, "var", 1] = "obs", p: str | NDArray[np.bool_] | NDArray[np.floating] | None = None, -) -> AnnData | None | tuple[np.ndarray | CSMatrix | DaskArray, NDArray[np.int64]]: +) -> AnnData | None | tuple[np.ndarray | _CSMatrix | DaskArray, NDArray[np.int64]]: """\ Sample observations or variables with or without replacement. @@ -970,7 +970,7 @@ def sample( return subset.to_memory() if data.isbacked else subset.copy() # overload 3: return array and indices - assert isinstance(subset, np.ndarray | CSMatrix | DaskArray), type(subset) + assert isinstance(subset, np.ndarray | _CSMatrix | DaskArray), type(subset) if copy: subset = subset.copy() return subset, indices @@ -1016,7 +1016,7 @@ def downsample_counts( ------- Returns `None` if `copy=False`, else returns an `AnnData` object. Sets the following fields: - `adata.X` : :class:`numpy.ndarray` | :class:`scipy.sparse.spmatrix` (dtype `float`) + `adata.X` : :class:`~numpy.ndarray` | :class:`~scipy.sparse.csr_matrix` | :class:`~scipy.sparse.csc_matrix` (dtype `float`) Downsampled counts matrix. 
""" raise_not_implemented_error_if_backed_type(adata.X, "downsample_counts") @@ -1029,14 +1029,24 @@ def downsample_counts( if copy: adata = adata.copy() if total_counts_call: - adata.X = _downsample_total_counts(adata.X, total_counts, random_state, replace) + adata.X = _downsample_total_counts( + adata.X, total_counts, random_state=random_state, replace=replace + ) elif counts_per_cell_call: - adata.X = _downsample_per_cell(adata.X, counts_per_cell, random_state, replace) + adata.X = _downsample_per_cell( + adata.X, counts_per_cell, random_state=random_state, replace=replace + ) if copy: return adata -def _downsample_per_cell(X, counts_per_cell, random_state, replace): +def _downsample_per_cell( + X: _CSMatrix, + counts_per_cell: int, + *, + random_state: _LegacyRandom, + replace: bool, +) -> _CSMatrix: n_obs = X.shape[0] if isinstance(counts_per_cell, int): counts_per_cell = np.full(n_obs, counts_per_cell) @@ -1053,7 +1063,7 @@ def _downsample_per_cell(X, counts_per_cell, random_state, replace): raise ValueError(msg) if issparse(X): original_type = type(X) - if not isspmatrix_csr(X): + if not isinstance(X, csr_matrix): X = csr_matrix(X) totals = np.ravel(axis_sum(X, axis=1)) # Faster for csr matrix under_target = np.nonzero(totals > counts_per_cell)[0] @@ -1085,14 +1095,20 @@ def _downsample_per_cell(X, counts_per_cell, random_state, replace): return X -def _downsample_total_counts(X, total_counts, random_state, replace): +def _downsample_total_counts( + X: _CSMatrix, + total_counts: int, + *, + random_state: _LegacyRandom, + replace: bool, +) -> _CSMatrix: total_counts = int(total_counts) total = X.sum() if total < total_counts: return X if issparse(X): original_type = type(X) - if not isspmatrix_csr(X): + if not isinstance(X, csr_matrix): X = csr_matrix(X) _downsample_array( X.data, diff --git a/src/scanpy/preprocessing/_utils.py b/src/scanpy/preprocessing/_utils.py index 3ca74734c0..755a687d67 100644 --- a/src/scanpy/preprocessing/_utils.py +++ b/src/scanpy/preprocessing/_utils.py @@ -9,7 +9,7 @@ from sklearn.random_projection import sample_without_replacement from .._compat import njit -from .._utils import axis_sum, elem_mul +from .._utils import _CSMatrix, axis_sum, elem_mul if TYPE_CHECKING: from typing import Literal @@ -34,8 +34,11 @@ def _(X: np.ndarray, *, axis: Literal[0, 1], dtype: DTypeLike) -> np.ndarray: def _get_mean_var( X: _SupportedArray, *, axis: Literal[0, 1] = 0 ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: - if isinstance(X, sparse.spmatrix): + if isinstance(X, _CSMatrix): mean, var = sparse_mean_variance_axis(X, axis=axis) + elif isinstance(X, sparse.spmatrix): + msg = f"Unsupported type {type(X)}" + raise TypeError(msg) else: mean = axis_mean(X, axis=axis, dtype=np.float64) mean_sq = axis_mean(elem_mul(X, X), axis=axis, dtype=np.float64) @@ -46,7 +49,9 @@ def _get_mean_var( return mean, var -def sparse_mean_variance_axis(mtx: sparse.spmatrix, axis: int): +def sparse_mean_variance_axis( + mtx: _CSMatrix, axis: Literal[0, 1] +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: """ This code and internal functions are based on sklearns `sparsefuncs.mean_variance_axis`. 
@@ -163,10 +168,7 @@ def sample_comb( return np.vstack(np.unravel_index(idx, dims)).T -def _to_dense( - X: sparse.spmatrix, - order: Literal["C", "F"] = "C", -) -> NDArray: +def _to_dense(X: _CSMatrix, order: Literal["C", "F"] = "C") -> NDArray: """\ Numba kernel for np.toarray() function """ diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index cccd7f96ce..97d1cea1b5 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -15,9 +15,10 @@ from typing import Literal from anndata import AnnData - from scipy import sparse from .._compat import _LegacyRandom + from .._utils import _CSMatrix + try: from leidenalg.VertexPartition import MutableVertexPartition @@ -36,7 +37,7 @@ def leiden( restrict_to: tuple[str, Sequence[str]] | None = None, random_state: _LegacyRandom = 0, key_added: str = "leiden", - adjacency: sparse.spmatrix | None = None, + adjacency: _CSMatrix | None = None, directed: bool | None = None, use_weights: bool = True, n_iterations: int = -1, diff --git a/src/scanpy/tools/_louvain.py b/src/scanpy/tools/_louvain.py index 1800dbe08e..88269a21d4 100644 --- a/src/scanpy/tools/_louvain.py +++ b/src/scanpy/tools/_louvain.py @@ -20,9 +20,9 @@ from typing import Any, Literal from anndata import AnnData - from scipy.sparse import spmatrix from .._compat import _LegacyRandom + from .._utils import _CSMatrix try: from louvain.VertexPartition import MutableVertexPartition @@ -55,7 +55,7 @@ def louvain( random_state: _LegacyRandom = 0, restrict_to: tuple[str, Sequence[str]] | None = None, key_added: str = "louvain", - adjacency: spmatrix | None = None, + adjacency: _CSMatrix | None = None, flavor: Literal["vtraag", "igraph", "rapids"] = "vtraag", directed: bool = True, use_weights: bool = False, diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 05e5738d99..206920fe87 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -24,10 +24,12 @@ from anndata import AnnData from numpy.typing import NDArray - from scipy import sparse + + from .._utils import _CSMatrix _CorrMethod = Literal["benjamini-hochberg", "bonferroni"] + # Used with get_literal_vals _Method = Literal["logreg", "t-test", "wilcoxon", "t-test_overestim_var"] @@ -45,7 +47,7 @@ def _select_top_n(scores: NDArray, n_top: int): def _ranks( - X: np.ndarray | sparse.csr_matrix | sparse.csc_matrix, + X: np.ndarray | _CSMatrix, mask_obs: NDArray[np.bool_] | None = None, mask_obs_rest: NDArray[np.bool_] | None = None, ) -> Generator[tuple[pd.DataFrame, int, int], None, None]: diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index 7dff1a300c..a67331e678 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -8,10 +8,9 @@ import pandas as pd from scipy.sparse import issparse -from scanpy._utils import _check_use_raw, is_backed_type - from .. 
import logging as logg from .._compat import old_positionals +from .._utils import _check_use_raw, is_backed_type from ..get import _get_obs_rep if TYPE_CHECKING: @@ -20,20 +19,18 @@ from anndata import AnnData from numpy.typing import DTypeLike, NDArray - from scipy.sparse import csc_matrix, csr_matrix from .._compat import _LegacyRandom + from .._utils import _CSMatrix try: _StrIdx = pd.Index[str] except TypeError: # Sphinx _StrIdx = pd.Index - _GetSubset = Callable[[_StrIdx], np.ndarray | csr_matrix | csc_matrix] + _GetSubset = Callable[[_StrIdx], np.ndarray | _CSMatrix] -def _sparse_nanmean( - X: csr_matrix | csc_matrix, axis: Literal[0, 1] -) -> NDArray[np.float64]: +def _sparse_nanmean(X: _CSMatrix, axis: Literal[0, 1]) -> NDArray[np.float64]: """ np.nanmean equivalent for sparse matrices """ diff --git a/src/scanpy/tools/_utils.py b/src/scanpy/tools/_utils.py index 4d24b5e276..e1a21cfac3 100644 --- a/src/scanpy/tools/_utils.py +++ b/src/scanpy/tools/_utils.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from anndata import AnnData - from scipy.sparse import csr_matrix + from scipy.sparse import csr_matrix, spmatrix def _choose_representation( @@ -106,29 +106,33 @@ def preprocess_with_pca(adata, n_pcs: int | None = None, random_state=0): def get_init_pos_from_paga( - adata, adjacency=None, random_state=0, neighbors_key=None, obsp=None + adata: AnnData, + adjacency: spmatrix | None = None, + random_state=0, + neighbors_key: str | None = None, + obsp: str | None = None, ): np.random.seed(random_state) if adjacency is None: adjacency = _choose_graph(adata, obsp, neighbors_key) - if "paga" in adata.uns and "pos" in adata.uns["paga"]: - groups = adata.obs[adata.uns["paga"]["groups"]] - pos = adata.uns["paga"]["pos"] - connectivities_coarse = adata.uns["paga"]["connectivities"] - init_pos = np.ones((adjacency.shape[0], 2)) - for i, group_pos in enumerate(pos): - subset = (groups == groups.cat.categories[i]).values - neighbors = connectivities_coarse[i].nonzero() - if len(neighbors[1]) > 0: - connectivities = connectivities_coarse[i][neighbors] - nearest_neighbor = neighbors[1][np.argmax(connectivities)] - noise = np.random.random((len(subset[subset]), 2)) - dist = pos[i] - pos[nearest_neighbor] - noise = noise * dist - init_pos[subset] = group_pos - 0.5 * dist + noise - else: - init_pos[subset] = group_pos - else: - msg = "Plot PAGA first, so that adata.uns['paga'] with key 'pos'." + if "pos" not in adata.uns.get("paga", {}): + msg = "Plot PAGA first, so that `adata.uns['paga']['pos']` exists." 
raise ValueError(msg) + + groups = adata.obs[adata.uns["paga"]["groups"]] + pos = adata.uns["paga"]["pos"] + connectivities_coarse = adata.uns["paga"]["connectivities"] + init_pos = np.ones((adjacency.shape[0], 2)) + for i, group_pos in enumerate(pos): + subset = (groups == groups.cat.categories[i]).values + neighbors = connectivities_coarse[i].nonzero() + if len(neighbors[1]) > 0: + connectivities = connectivities_coarse[i][neighbors] + nearest_neighbor = neighbors[1][np.argmax(connectivities)] + noise = np.random.random((len(subset[subset]), 2)) + dist = pos[i] - pos[nearest_neighbor] + noise = noise * dist + init_pos[subset] = group_pos - 0.5 * dist + noise + else: + init_pos[subset] = group_pos return init_pos diff --git a/src/scanpy/tools/_utils_clustering.py b/src/scanpy/tools/_utils_clustering.py index 3c771e5d74..03186f86bb 100644 --- a/src/scanpy/tools/_utils_clustering.py +++ b/src/scanpy/tools/_utils_clustering.py @@ -3,13 +3,14 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Sequence import numpy as np import pandas as pd from anndata import AnnData from numpy.typing import NDArray - from scipy.sparse import spmatrix + + from .._utils import _CSMatrix def rename_groups( @@ -33,9 +34,9 @@ def restrict_adjacency( adata: AnnData, restrict_key: str, *, - restrict_categories: Iterable[str], - adjacency: spmatrix, -) -> tuple[spmatrix, NDArray[np.bool_]]: + restrict_categories: Sequence[str], + adjacency: _CSMatrix, +) -> tuple[_CSMatrix, NDArray[np.bool_]]: if not isinstance(restrict_categories[0], str): msg = "You need to use strings to label categories, e.g. '1' instead of 1." raise ValueError(msg) diff --git a/src/testing/scanpy/_pytest/fixtures/data.py b/src/testing/scanpy/_pytest/fixtures/data.py index 4d44d8239b..bd316bad09 100644 --- a/src/testing/scanpy/_pytest/fixtures/data.py +++ b/src/testing/scanpy/_pytest/fixtures/data.py @@ -36,6 +36,8 @@ def make_sparse(x): from numpy.typing import DTypeLike + _CSMatrix = sparse.csr_matrix | sparse.csc_matrix + @pytest.fixture( scope="session", @@ -97,9 +99,7 @@ def backed_adata( def _prepare_pbmc_testdata( adata: AnnData, - sparsity_func: Callable[ - [np.ndarray | sparse.spmatrix], np.ndarray | sparse.spmatrix - ], + sparsity_func: Callable[[np.ndarray | _CSMatrix], np.ndarray | _CSMatrix], dtype: DTypeLike, *, small: bool, diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 833ff0d40d..5c5169fcf6 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -14,7 +14,6 @@ from scipy import sparse import scanpy as sc -from scanpy._compat import DaskArray from testing.scanpy._helpers.data import pbmc68k_reduced from testing.scanpy._pytest.params import ARRAY_TYPES @@ -117,26 +116,33 @@ def test_correctness(metric, size, expected): np.testing.assert_equal(metric(adata, vals=connected), expected) -@pytest.mark.parametrize("array_type", ARRAY_TYPES) -def test_graph_metrics_w_constant_values(metric, array_type, threading): +@pytest.mark.parametrize( + "array_type", [*ARRAY_TYPES, pytest.param(sparse.coo_matrix, id="scipy_coo")] +) +def test_graph_metrics_w_constant_values( + request: pytest.FixtureRequest, metric, array_type, threading +): + if "dask" in array_type.__name__: + reason = "DaskArray not yet supported" + request.applymarker(pytest.mark.xfail(reason=reason)) + # https://github.com/scverse/scanpy/issues/1806 pbmc = pbmc68k_reduced() - XT = array_type(pbmc.raw.X.T.copy()) + XT = pbmc.raw.X.T.copy() g = 
pbmc.obsp["connectivities"].copy() equality_check = partial(np.testing.assert_allclose, atol=1e-11) - if isinstance(XT, DaskArray): - pytest.skip("DaskArray yet not supported") - const_inds = np.random.choice(XT.shape[0], 10, replace=False) with warnings.catch_warnings(): warnings.simplefilter("ignore", sparse.SparseEfficiencyWarning) XT_zero_vals = XT.copy() XT_zero_vals[const_inds, :] = 0 + XT_zero_vals = array_type(XT_zero_vals) XT_const_vals = XT.copy() XT_const_vals[const_inds, :] = 42 + XT_const_vals = array_type(XT_const_vals) - results_full = metric(g, XT) + results_full = metric(g, array_type(XT)) # TODO: Check for warnings with pytest.warns( UserWarning, match=r"10 variables were constant, will return nan for these" diff --git a/tests/test_qc_metrics.py b/tests/test_qc_metrics.py index 7ca6534b7c..9c4cb75c12 100644 --- a/tests/test_qc_metrics.py +++ b/tests/test_qc_metrics.py @@ -75,10 +75,13 @@ def test_segments_binary(): @pytest.mark.parametrize( - "cls", [np.asarray, sparse.csr_matrix, sparse.csc_matrix, sparse.coo_matrix] + "array_type", [*ARRAY_TYPES, pytest.param(sparse.coo_matrix, id="scipy_coo")] ) -def test_top_segments(cls): - a = cls(np.ones((300, 100))) +def test_top_segments(request: pytest.FixtureRequest, array_type): + if "dask" in array_type.__name__: + reason = "DaskArray not yet supported" + request.applymarker(pytest.mark.xfail(reason=reason)) + a = array_type(np.ones((300, 100))) seg = top_segment_proportions(a, [50, 100]) assert (seg[:, 0] == 0.5).all() assert (seg[:, 1] == 1.0).all() @@ -100,7 +103,7 @@ def test_qc_metrics(adata_prepared: AnnData): else adata_prepared.X ) max_X = X.max(axis=0) - if isinstance(max_X, sparse.spmatrix): + if isinstance(max_X, sparse.coo_matrix): max_X = max_X.toarray() elif isinstance(max_X, DaskArray): max_X = max_X.compute() diff --git a/tests/test_utils.py b/tests/test_utils.py index 81369a6938..76c5a525a2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -8,7 +8,7 @@ import pytest from anndata.tests.helpers import asarray from packaging.version import Version -from scipy.sparse import csr_matrix, issparse +from scipy.sparse import coo_matrix, csr_matrix, issparse from scanpy._compat import DaskArray, _legacy_numpy_gen, pkg_version from scanpy._utils import ( @@ -151,7 +151,9 @@ def test_elem_mul(array_type): np.testing.assert_array_equal(res, expd) -@pytest.mark.parametrize("array_type", ARRAY_TYPES) +@pytest.mark.parametrize( + "array_type", [*ARRAY_TYPES, pytest.param(coo_matrix, id="scipy_coo")] +) def test_axis_sum(array_type): m1 = array_type(asarray([[0, 1, 1], [1, 0, 1]])) expd_0 = np.array([1, 1, 2]) From 75c246d191aaabb02e6cce9d926de9fb40dbc77e Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 17 Jan 2025 14:10:35 +0100
Subject: [PATCH 21/24] (chore): Update version number inference in dev
 environments (#3441)

---
 .gitignore                     |  1 -
 docs/release-notes/3441.dev.md |  1 +
 hatch.toml                     |  2 +-
 pyproject.toml                 | 10 +++------
 src/scanpy/__init__.py         | 13 +-----------
 src/scanpy/_version.py         | 37 ++++++++++++++++++++++++++++++++++
 6 files changed, 43 insertions(+), 21 deletions(-)
 create mode 100644 docs/release-notes/3441.dev.md
 create mode 100644 src/scanpy/_version.py

diff --git a/.gitignore b/.gitignore
index 65f9de7e0a..de85b8a6b7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,7 +27,6 @@

 # Python build files
 __pycache__/
-/src/scanpy/_version.py
 /ci/scanpy-min-deps.txt
 /dist/
 /*-env/
diff --git a/docs/release-notes/3441.dev.md b/docs/release-notes/3441.dev.md
new file mode 100644
index 0000000000..5b95e1ab22
--- /dev/null
+++ b/docs/release-notes/3441.dev.md
@@ -0,0 +1 @@
+Fix version number inference in development environments (CI and local) {smaller}`P Angerer`
diff --git a/hatch.toml b/hatch.toml
index b0a1084c61..d61f91f577 100644
--- a/hatch.toml
+++ b/hatch.toml
@@ -15,7 +15,7 @@ scripts.clean = "git restore --source=HEAD --staged --worktree -- docs/release-n

 [envs.hatch-test]
 default-args = [  ]
-features = [ "test", "dask-ml" ]
+features = [ "dev", "test", "dask-ml" ]
 extra-dependencies = [ "ipykernel" ]
 overrides.matrix.deps.env-vars = [
   { if = [ "pre" ], key = "UV_PRERELEASE", value = "allow" },
diff --git a/pyproject.toml b/pyproject.toml
index 71f7f1c482..379f51d8c6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -132,11 +132,9 @@ doc = [
   "sam-algorithm",
 ]
 dev = [
-  # getting the dev version
-  "setuptools_scm",
-  # static checking
-  "pre-commit",
-  "towncrier",
+  "hatch-vcs",   # runtime dev version generation
+  "pre-commit",  # static checking
+  "towncrier",   # release note management
 ]
 # Algorithms
 paga = [ "igraph" ]
@@ -158,8 +156,6 @@ packages = [ "src/testing", "src/scanpy" ]
 [tool.hatch.version]
 source = "vcs"
 raw-options.version_scheme = "release-branch-semver"
-[tool.hatch.build.hooks.vcs]
-version-file = "src/scanpy/_version.py"

 [tool.pytest.ini_options]
 addopts = [
diff --git a/src/scanpy/__init__.py b/src/scanpy/__init__.py
index b844372d1e..d6bfc4ac4b 100644
--- a/src/scanpy/__init__.py
+++ b/src/scanpy/__init__.py
@@ -6,19 +6,8 @@

 from packaging.version import Version

-try:  # See https://github.com/maresb/hatch-vcs-footgun-example
-    from setuptools_scm import get_version
-
-    __version__ = get_version(root="../..", relative_to=__file__)
-    del get_version
-except (ImportError, LookupError):
-    try:
-        from ._version import __version__
-    except ModuleNotFoundError:
-        msg = "scanpy is not correctly installed. Please install it, e.g. with pip."
-        raise RuntimeError(msg)
-
 from ._utils import check_versions
+from ._version import __version__

 check_versions()
 del check_versions
diff --git a/src/scanpy/_version.py b/src/scanpy/_version.py
new file mode 100644
index 0000000000..db1dda9329
--- /dev/null
+++ b/src/scanpy/_version.py
@@ -0,0 +1,37 @@
+"""Get version from VCS in a dev environment or from package metadata in production.
+
+See <https://github.com/maresb/hatch-vcs-footgun-example>.
+""" + +from __future__ import annotations + +from pathlib import Path + +__all__ = ["__version__"] + + +def _get_version_from_vcs() -> str: # pragma: no cover + from hatchling.metadata.core import ProjectMetadata + from hatchling.plugin.exceptions import UnknownPluginError + from hatchling.plugin.manager import PluginManager + from hatchling.utils.fs import locate_file + + if (pyproject_toml := locate_file(__file__, "pyproject.toml")) is None: + msg = "pyproject.toml not found although hatchling is installed" + raise LookupError(msg) + root = Path(pyproject_toml).parent + metadata = ProjectMetadata(root=str(root), plugin_manager=PluginManager()) + try: + # Version can be either statically set in pyproject.toml or computed dynamically: + return metadata.core.version or metadata.hatch.version.cached + except UnknownPluginError: + msg = "Unable to import hatch plugin." + raise ImportError(msg) + + +try: + __version__ = _get_version_from_vcs() +except (ImportError, LookupError): + import importlib.metadata + + __version__ = importlib.metadata.version("scanpy") From 7dfd339adf0d591e6df1e96b5f56e93ba1e3d939 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Jan 2025 09:48:30 +0100 Subject: [PATCH 22/24] [pre-commit.ci] pre-commit autoupdate (#3445) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c5e0e91d8c..4d0e01d98d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.1 + rev: v0.9.2 hooks: - id: ruff args: ["--fix"] From 6c89e1dcd9452593ea40967a0bb9a06faaa38dc2 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 21 Jan 2025 10:09:58 +0100 Subject: [PATCH 23/24] Skip COO matrix tests for newer anndata (#3442) --- tests/test_qc_metrics.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/test_qc_metrics.py b/tests/test_qc_metrics.py index 9c4cb75c12..fd043ab0dc 100644 --- a/tests/test_qc_metrics.py +++ b/tests/test_qc_metrics.py @@ -1,10 +1,13 @@ from __future__ import annotations +from importlib.metadata import version + import numpy as np import pandas as pd import pytest from anndata import AnnData from anndata.tests.helpers import assert_equal +from packaging.version import Version from scipy import sparse import scanpy as sc @@ -18,7 +21,7 @@ ) from testing.scanpy._helpers import as_sparse_dask_array, maybe_dask_process_context from testing.scanpy._pytest.marks import needs -from testing.scanpy._pytest.params import ARRAY_TYPES +from testing.scanpy._pytest.params import ARRAY_TYPES, ARRAY_TYPES_MEM @pytest.fixture @@ -198,8 +201,18 @@ def adata_mito(): return adata_dense, init_var +skip_if_adata_0_12 = pytest.mark.skipif( + Version(version("anndata")) >= Version("0.12.0.dev0"), + reason="Newer AnnData removes implicit support for COO matrices", +) + + @pytest.mark.parametrize( - "cls", [np.asarray, sparse.csr_matrix, sparse.csc_matrix, sparse.coo_matrix] + "cls", + [ + *ARRAY_TYPES_MEM, + pytest.param(sparse.coo_matrix, marks=[skip_if_adata_0_12], id="scipy_coo"), + ], ) def test_qc_metrics_format(cls): adata_dense, init_var = adata_mito() From 8ce811ac3ab6674a7d6235f44afda3e10e366682 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Tue, 21 Jan 2025 14:37:13 +0100 Subject: [PATCH 24/24] Implement missing plot functionality for `pl.rank_genes_groups` (#3428) --- docs/release-notes/3428.bugfix.md | 1 + src/scanpy/plotting/_tools/__init__.py | 83 ++++++++++++++++---------- tests/test_plotting.py | 30 +++++++++- 3 files changed, 78 insertions(+), 36 deletions(-) create mode 100644 docs/release-notes/3428.bugfix.md diff --git a/docs/release-notes/3428.bugfix.md b/docs/release-notes/3428.bugfix.md new file mode 100644 index 0000000000..febe18a1ca --- /dev/null +++ b/docs/release-notes/3428.bugfix.md @@ -0,0 +1 @@ +Fix {func}`scanpy.pl.rank_genes_groups`’s `ax` parameter {smaller}`P Angerer` diff --git a/src/scanpy/plotting/_tools/__init__.py b/src/scanpy/plotting/_tools/__init__.py index 8f189121e2..8c89b34fb4 100644 --- a/src/scanpy/plotting/_tools/__init__.py +++ b/src/scanpy/plotting/_tools/__init__.py @@ -47,6 +47,7 @@ from matplotlib.figure import Figure from ..._utils import Empty + from .._baseplot_class import BasePlot from .._utils import DensityNorm # ------------------------------------------------------------------------------ @@ -346,7 +347,7 @@ def rank_genes_groups( save: bool | None = None, ax: Axes | None = None, **kwds, -): +) -> list[Axes] | None: """\ Plot ranking of genes. @@ -370,6 +371,9 @@ def rank_genes_groups( `sharey=False`, each panel has its own y-axis range. {show_save_ax} + Returns + ------- + List of each group’s matplotlib axis or `None` if `show=True`. Examples -------- @@ -413,15 +417,26 @@ def rank_genes_groups( from matplotlib import gridspec - fig = plt.figure( - figsize=( - n_panels_x * rcParams["figure.figsize"][0], - n_panels_y * rcParams["figure.figsize"][1], + if ax is None or (sps := ax.get_subplotspec()) is None: + fig = ( + plt.figure( + figsize=( + n_panels_x * rcParams["figure.figsize"][0], + n_panels_y * rcParams["figure.figsize"][1], + ) + ) + if ax is None + else ax.get_figure() ) - ) - gs = gridspec.GridSpec(nrows=n_panels_y, ncols=n_panels_x, wspace=0.22, hspace=0.3) + gs = gridspec.GridSpec(n_panels_y, n_panels_x, fig, wspace=0.22, hspace=0.3) + else: + fig = ax.get_figure() + gs = sps.subgridspec(n_panels_y, n_panels_x) + if fig is None: + msg = "passed ax has no associated figure" + raise RuntimeError(msg) - ax0 = None + axs: list[Axes] = [] ymin = np.inf ymax = -np.inf for count, group_name in enumerate(group_names): @@ -433,20 +448,16 @@ def rank_genes_groups( ymin = min(ymin, np.min(scores)) ymax = max(ymax, np.max(scores)) - if ax0 is None: - ax = fig.add_subplot(gs[count]) - ax0 = ax - else: - ax = fig.add_subplot(gs[count], sharey=ax0) + axs.append(fig.add_subplot(gs[count], sharey=axs[0] if axs else None)) else: ymin = np.min(scores) ymax = np.max(scores) ymax += 0.3 * (ymax - ymin) - ax = fig.add_subplot(gs[count]) - ax.set_ylim(ymin, ymax) + axs.append(fig.add_subplot(gs[count])) + axs[-1].set_ylim(ymin, ymax) - ax.set_xlim(-0.9, n_genes - 0.1) + axs[-1].set_xlim(-0.9, n_genes - 0.1) # Mapping to gene_symbols if gene_symbols is not None: @@ -457,7 +468,7 @@ def rank_genes_groups( # Making labels for ig, gene_name in enumerate(gene_names): - ax.text( + axs[-1].text( ig, scores[ig], gene_name, @@ -467,23 +478,29 @@ def rank_genes_groups( fontsize=fontsize, ) - ax.set_title(f"{group_name} vs. {reference}") + axs[-1].set_title(f"{group_name} vs. {reference}") if count >= n_panels_x * (n_panels_y - 1): - ax.set_xlabel("ranking") + axs[-1].set_xlabel("ranking") # print the 'score' label only on the first panel per row. 
         if count % n_panels_x == 0:
-            ax.set_ylabel("score")
+            axs[-1].set_ylabel("score")

-    if sharey is True:
+    if sharey is True and axs:
         ymax += 0.3 * (ymax - ymin)
-        ax.set_ylim(ymin, ymax)
+        axs[0].set_ylim(ymin, ymax)

     writekey = f"rank_genes_groups_{adata.uns[key]['params']['groupby']}"
     savefig_or_show(writekey, show=show, save=save)
+    show = settings.autoshow if show is None else show
+    if show:
+        return None
+    return axs


-def _fig_show_save_or_axes(plot_obj, return_fig, show, save):
+def _fig_show_save_or_axes(
+    plot_obj: BasePlot, *, return_fig: bool, show: bool | None, save: bool | None
+):
     """
     Decides what to return
     """
@@ -510,7 +527,7 @@ def _rank_genes_groups_plot(
     key: str | None = None,
     show: bool | None = None,
     save: bool | None = None,
-    return_fig: bool | None = False,
+    return_fig: bool = False,
     gene_symbols: str | None = None,
     **kwds,
 ):
@@ -524,10 +541,6 @@ def _rank_genes_groups_plot(
         )
         raise ValueError(msg)

-    if var_names is None and n_genes is None:
-        # set n_genes = 10 as default when none of the options is given
-        n_genes = 10
-
     if key is None:
         key = "rank_genes_groups"

@@ -544,6 +557,10 @@ def _rank_genes_groups_plot(
         else:
             var_names_list = var_names
     else:
+        # set n_genes = 10 as default when none of the options is given
+        if n_genes is None:
+            n_genes = 10
+
         # dict in which each group is the key and the n_genes are the values
         var_names = {}
         var_names_list = []
@@ -621,7 +638,7 @@ def _rank_genes_groups_plot(
         if title is not None and "colorbar_title" not in kwds:
             _pl.legend(title=title)

-        return _fig_show_save_or_axes(_pl, return_fig, show, save)
+        return _fig_show_save_or_axes(_pl, return_fig=return_fig, show=show, save=save)

     elif plot_type == "stacked_violin":
         from .._stacked_violin import stacked_violin
@@ -634,7 +651,7 @@ def _rank_genes_groups_plot(
             gene_symbols=gene_symbols,
             **kwds,
         )
-        return _fig_show_save_or_axes(_pl, return_fig, show, save)
+        return _fig_show_save_or_axes(_pl, return_fig=return_fig, show=show, save=save)

     elif plot_type == "heatmap":
         from .._anndata import heatmap
@@ -846,7 +863,7 @@ def rank_genes_groups_dotplot(
     key: str | None = None,
     show: bool | None = None,
     save: bool | None = None,
-    return_fig: bool | None = False,
+    return_fig: bool = False,
     **kwds,
 ):
     """\
@@ -985,7 +1002,7 @@ def rank_genes_groups_stacked_violin(
     key: str | None = None,
     show: bool | None = None,
     save: bool | None = None,
-    return_fig: bool | None = False,
+    return_fig: bool = False,
     **kwds,
 ):
     """\
@@ -1073,7 +1090,7 @@ def rank_genes_groups_matrixplot(
     key: str | None = None,
     show: bool | None = None,
     save: bool | None = None,
-    return_fig: bool | None = False,
+    return_fig: bool = False,
     **kwds,
 ):
     """\
diff --git a/tests/test_plotting.py b/tests/test_plotting.py
index f135a68aa4..161b493823 100644
--- a/tests/test_plotting.py
+++ b/tests/test_plotting.py
@@ -28,6 +28,9 @@
 if TYPE_CHECKING:
     from collections.abc import Callable

+    from matplotlib.axes import Axes
+
+
 HERE: Path = Path(__file__).parent
 ROOT = HERE / "_images"

@@ -842,9 +845,30 @@ def test_rank_genes_groups(image_comparer, name, fn):
     with plt.rc_context({"axes.grid": True, "figure.figsize": (4, 4)}):
         fn(pbmc)

-    key = "ranked_genes" if name == "basic" else f"ranked_genes_{name}"
-    save_and_compare_images(key)
-    plt.close()
+        key = "ranked_genes" if name == "basic" else f"ranked_genes_{name}"
+        save_and_compare_images(key)
+        plt.close()
+
+
+def test_rank_genes_group_axes(image_comparer):
+    fn = next(fn for name, fn in _RANK_GENES_GROUPS_PARAMS if name == "basic")
+
+    save_and_compare_images = partial(image_comparer, ROOT, tol=23)
+
+    pbmc = pbmc68k_reduced()
+    sc.tl.rank_genes_groups(pbmc, "louvain", n_genes=pbmc.raw.shape[1])
+
+    pbmc.var["symbol"] = pbmc.var.index + "__"
+
+    fig, ax = plt.subplots(figsize=(12, 16))
+    ax.set_axis_off()
+    with plt.rc_context({"axes.grid": True}):
+        axes: list[Axes] = fn(pbmc, ax=ax, show=False)
+
+    assert len(axes) == 11
+    fig.show()
+    save_and_compare_images("ranked_genes")
+    plt.close()


 @pytest.fixture(scope="session")