diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 6338a7575e..9199394032 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -671,7 +671,7 @@ def _handle_transformer( # Validate `knn` conn_method = method if method in {"gauss", None} else "umap" - if not knn and not (conn_method == "gauss" and transformer is None): + if not knn and conn_method != "gauss": # “knn=False” seems to be only intended for method “gauss” msg = f"`method = {method!r} only with `knn = True`." raise ValueError(msg) diff --git a/src/scanpy/neighbors/_backends/pairwise.py b/src/scanpy/neighbors/_backends/pairwise.py index 19f6e472dc..485a895367 100644 --- a/src/scanpy/neighbors/_backends/pairwise.py +++ b/src/scanpy/neighbors/_backends/pairwise.py @@ -44,11 +44,16 @@ def transform(self, y: _MatrixLike | None) -> _CSMatrix: from sklearn.metrics import pairwise_distances d_arr = pairwise_distances(self.x_, y, metric=self.metric, **self.metric_params) - ind, dist = _get_indices_distances_from_dense_matrix(d_arr, self.n_neighbors) + ind, dist = _get_indices_distances_from_dense_matrix( + d_arr, self.n_neighbors + 1 + ) rv = _get_sparse_matrix_from_indices_distances(ind, dist, keep_self=True) if _DEBUG: + if self.n_neighbors >= d_arr.shape[1] - 1: + np.testing.assert_equal(d_arr, rv.toarray()) + ind2, dist2 = _get_indices_distances_from_sparse_matrix( - rv, self.n_neighbors + rv, self.n_neighbors + 1 ) np.testing.assert_equal(ind, ind2) np.testing.assert_equal(dist, dist2) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 0500eb6f9f..7ac6cb567d 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -147,6 +147,8 @@ def test_distances_euclidean( [ # knn=False trivially returns all distances pytest.param(None, False, id="knn=False"), + # knn=False with pairwise returns all distances + pytest.param("sklearn-pairwise", False, id="pairwise"), # pynndescent returns all distances when data is so small pytest.param("pynndescent", True, id="pynndescent"), # Explicit brute force also returns all distances @@ -157,7 +159,13 @@ def test_distances_euclidean( ), ], ) -def test_distances_all(neigh: Neighbors, transformer, knn): +def test_distances_all( + monkeypatch: pytest.MonkeyPatch, neigh: Neighbors, transformer, knn +): + from scanpy.neighbors._backends import pairwise + + monkeypatch.setattr(pairwise, "_DEBUG", True) + neigh.compute_neighbors( n_neighbors, transformer=transformer, method="gauss", knn=knn )