Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix neighbors transformer shortcut #3444

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions src/scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .._settings import settings
from .._utils import NeighborsView, _doc_params, get_literal_vals
from . import _connectivity
from ._backends.pairwise import PairwiseDistancesTransformer
from ._common import (
_get_indices_distances_from_sparse_matrix,
_get_sparse_matrix_from_indices_distances,
Expand Down Expand Up @@ -92,6 +93,12 @@ def neighbors(
connectivities are computed according to :cite:t:`Coifman2005`, in the adaption of
:cite:t:`Haghverdi2016`.

.. note::

The results have changed slightly since scanpy 1.10.
We recommend ensuring reproducibility by pinning all package versions,
but you can get the old results by specifying `transformer='sklearn-pairwise'`.

Parameters
----------
adata
Expand Down Expand Up @@ -122,10 +129,13 @@ def neighbors(
See :doc:`/how-to/knn-transformers` for more details.
Also accepts the following known options:

`None` (the default)
`None` | `'sklearn'` (the default)
Behavior depends on data size.
For small data, we will calculate exact kNN, otherwise we use
:class:`~pynndescent.pynndescent_.PyNNDescentTransformer`
`'sklearn-pairwise'`
For compatibility with scanpy <1.10, this allows using
:func:`~sklearn.metrics.pairwise_distances`.
`'pynndescent'`
:class:`~pynndescent.pynndescent_.PyNNDescentTransformer`
`'rapids'`
Expand Down Expand Up @@ -642,7 +652,7 @@ def _handle_transformer(
use_dense_distances = (
kwds["metric"] == "euclidean" and self._adata.n_obs < 8192
) or not knn
shortcut = transformer == "sklearn" or (
shortcut = transformer in {"sklearn", "sklearn-pairwise"} or (
transformer is None and (use_dense_distances or self._adata.n_obs < 4096)
)

Expand All @@ -661,7 +671,7 @@ def _handle_transformer(

# Validate `knn`
conn_method = method if method in {"gauss", None} else "umap"
if not knn and not (conn_method == "gauss" and transformer is None):
if not knn and conn_method != "gauss":
# “knn=False” seems to be only intended for method “gauss”
msg = f"`method = {method!r} only with `knn = True`."
raise ValueError(msg)
Expand All @@ -670,11 +680,19 @@ def _handle_transformer(
if shortcut:
from sklearn.neighbors import KNeighborsTransformer

assert transformer in {None, "sklearn"}
assert transformer in {None, "sklearn", "sklearn-pairwise"}
n_neighbors = self._adata.n_obs - 1
if knn: # only obey n_neighbors arg if knn set
n_neighbors = min(n_neighbors, kwds["n_neighbors"])
transformer = KNeighborsTransformer(

# sklearn-pairwise is opt-in, because it takes more memory
transformer_cls = (
PairwiseDistancesTransformer
if transformer == "sklearn-pairwise"
else KNeighborsTransformer
)

transformer = transformer_cls(
algorithm="brute",
n_jobs=settings.n_jobs,
n_neighbors=n_neighbors,
Expand Down
60 changes: 60 additions & 0 deletions src/scanpy/neighbors/_backends/pairwise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

from dataclasses import KW_ONLY, dataclass
from typing import TYPE_CHECKING

import numpy as np
from sklearn.base import TransformerMixin

from .._common import (
_get_indices_distances_from_dense_matrix,
_get_indices_distances_from_sparse_matrix,
_get_sparse_matrix_from_indices_distances,
)

if TYPE_CHECKING:
from collections.abc import Mapping
from typing import Literal, Self

from numpy.typing import NDArray

from ..._utils import _CSMatrix

_Metric = Literal["cityblock", "cosine", "euclidean", "l1", "l2", "manhattan"]
_MatrixLike = NDArray | _CSMatrix


# Debug switch: when true, `PairwiseDistancesTransformer.transform` runs extra
# `np.testing` consistency checks on its result before returning it.
_DEBUG = False


@dataclass
class PairwiseDistancesTransformer(TransformerMixin):
    """Exact neighbor transformer backed by :func:`sklearn.metrics.pairwise_distances`.

    Builds the full dense distance matrix between the query data and the
    fitted data and converts it into a sparse matrix via the helpers
    imported from ``_common``.

    All fields are keyword-only (``KW_ONLY`` sentinel).
    """

    _: KW_ONLY
    # Not read by `fit`/`transform`.
    # NOTE(review): presumably accepted so this class can be constructed with
    # the same keyword arguments as sklearn's `KNeighborsTransformer` — confirm.
    algorithm: Literal["brute"]
    # Accepted but not forwarded to `pairwise_distances`.
    n_jobs: int
    # `transform` asks the dense-matrix helper for `n_neighbors + 1` entries.
    n_neighbors: int
    # Passed to `pairwise_distances` as `metric=`.
    metric: _Metric
    # Unpacked into `pairwise_distances` as extra keyword arguments.
    metric_params: Mapping[str, object]

    def fit(self, x: _MatrixLike) -> Self:
        """Store `x` as the reference data (`self.x_`) and return `self`."""
        self.x_ = x
        return self

    def transform(self, y: _MatrixLike | None) -> _CSMatrix:
        """Compute distances from the query data `y` to the fitted data.

        Parameters
        ----------
        y
            Query data. If `None`, the fitted data is used as the query.

        Returns
        -------
        Sparse matrix built by `_get_sparse_matrix_from_indices_distances`
        from the dense distances, whose rows correspond to the query samples
        and whose columns correspond to the fitted samples.
        """
        from sklearn.metrics import pairwise_distances

        # Rows must index the query samples and columns the fitted samples,
        # matching the (n_queries, n_samples_fit) layout of sklearn's neighbor
        # transformers. For `fit_transform`, query and fitted data coincide.
        queries = self.x_ if y is None else y
        d_arr = pairwise_distances(
            queries, self.x_, metric=self.metric, **self.metric_params
        )
        # NOTE(review): the `+ 1` presumably accounts for each sample being
        # its own nearest neighbor (`keep_self=True` below) — confirm.
        ind, dist = _get_indices_distances_from_dense_matrix(
            d_arr, self.n_neighbors + 1
        )
        rv = _get_sparse_matrix_from_indices_distances(ind, dist, keep_self=True)
        if _DEBUG:
            # When every fitted sample is requested as a neighbor, the sparse
            # result must reproduce the dense distance matrix exactly.
            if self.n_neighbors >= d_arr.shape[1] - 1:
                np.testing.assert_equal(d_arr, rv.toarray())

            # Round trip: extracting indices/distances from the sparse result
            # must give back what it was built from.
            ind2, dist2 = _get_indices_distances_from_sparse_matrix(
                rv, self.n_neighbors + 1
            )
            np.testing.assert_equal(ind, ind2)
            np.testing.assert_equal(dist, dist2)
        return rv
2 changes: 1 addition & 1 deletion src/scanpy/neighbors/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

# These two are used with get_literal_vals elsewhere
_Method = Literal["umap", "gauss"]
_KnownTransformer = Literal["pynndescent", "sklearn", "rapids"]
_KnownTransformer = Literal["pynndescent", "sklearn", "sklearn-pairwise", "rapids"]

# sphinx-autodoc-typehints can’t transitively import types from `if TYPE_CHECKING` blocks,
# so these four need to be importable
Expand Down
15 changes: 12 additions & 3 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from pytest_mock import MockerFixture


# the input data
X = [[1, 0], [3, 0], [5, 6], [0, 4]]
n_neighbors = 3 # includes data points themselves
Expand Down Expand Up @@ -146,6 +147,8 @@ def test_distances_euclidean(
[
# knn=False trivially returns all distances
pytest.param(None, False, id="knn=False"),
# knn=False with pairwise returns all distances
pytest.param("sklearn-pairwise", False, id="pairwise"),
# pynndescent returns all distances when data is so small
pytest.param("pynndescent", True, id="pynndescent"),
# Explicit brute force also returns all distances
Expand All @@ -156,7 +159,13 @@ def test_distances_euclidean(
),
],
)
def test_distances_all(neigh: Neighbors, transformer, knn):
def test_distances_all(
monkeypatch: pytest.MonkeyPatch, neigh: Neighbors, transformer, knn
):
from scanpy.neighbors._backends import pairwise

monkeypatch.setattr(pairwise, "_DEBUG", True)

neigh.compute_neighbors(
n_neighbors, transformer=transformer, method="gauss", knn=knn
)
Expand Down Expand Up @@ -191,7 +200,7 @@ def test_connectivities_euclidean(neigh: Neighbors, method, conn, trans, trans_s
np.testing.assert_allclose(neigh.transitions.toarray(), trans, rtol=1e-5)


def test_gauss_noknn_connectivities_euclidean(neigh):
def test_gauss_noknn_connectivities_euclidean(neigh: Neighbors):
neigh.compute_neighbors(n_neighbors, method="gauss", knn=False)
np.testing.assert_allclose(neigh.connectivities, connectivities_gauss_noknn)
neigh.compute_transitions()
Expand Down Expand Up @@ -226,7 +235,7 @@ def test_use_rep_argument():


@pytest.mark.parametrize("conv", [csr_matrix.toarray, csr_matrix])
def test_restore_n_neighbors(neigh, conv):
def test_restore_n_neighbors(neigh: Neighbors, conv):
neigh.compute_neighbors(n_neighbors, method="gauss")

ad = AnnData(np.array(X))
Expand Down
Loading