Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Jan 23, 2025
1 parent 59d40ae commit 51a9e22
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def _handle_transformer(

# Validate `knn`
conn_method = method if method in {"gauss", None} else "umap"
if not knn and not (conn_method == "gauss" and transformer is None):
if not knn and conn_method != "gauss":
# “knn=False” seems to be only intended for method “gauss”
msg = f"`method = {method!r} only with `knn = True`."
raise ValueError(msg)
Expand Down
9 changes: 7 additions & 2 deletions src/scanpy/neighbors/_backends/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,16 @@ def transform(self, y: _MatrixLike | None) -> _CSMatrix:
from sklearn.metrics import pairwise_distances

d_arr = pairwise_distances(self.x_, y, metric=self.metric, **self.metric_params)
ind, dist = _get_indices_distances_from_dense_matrix(d_arr, self.n_neighbors)
ind, dist = _get_indices_distances_from_dense_matrix(
d_arr, self.n_neighbors + 1
)
rv = _get_sparse_matrix_from_indices_distances(ind, dist, keep_self=True)
if _DEBUG:
if self.n_neighbors >= d_arr.shape[1] - 1:
np.testing.assert_equal(d_arr, rv.toarray())

ind2, dist2 = _get_indices_distances_from_sparse_matrix(
rv, self.n_neighbors
rv, self.n_neighbors + 1
)
np.testing.assert_equal(ind, ind2)
np.testing.assert_equal(dist, dist2)
Expand Down
10 changes: 9 additions & 1 deletion tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def test_distances_euclidean(
[
# knn=False trivially returns all distances
pytest.param(None, False, id="knn=False"),
# knn=False with pairwise returns all distances
pytest.param("sklearn-pairwise", False, id="pairwise"),
# pynndescent returns all distances when data is so small
pytest.param("pynndescent", True, id="pynndescent"),
# Explicit brute force also returns all distances
Expand All @@ -157,7 +159,13 @@ def test_distances_euclidean(
),
],
)
def test_distances_all(neigh: Neighbors, transformer, knn):
def test_distances_all(
monkeypatch: pytest.MonkeyPatch, neigh: Neighbors, transformer, knn
):
from scanpy.neighbors._backends import pairwise

monkeypatch.setattr(pairwise, "_DEBUG", True)

neigh.compute_neighbors(
n_neighbors, transformer=transformer, method="gauss", knn=knn
)
Expand Down

0 comments on commit 51a9e22

Please sign in to comment.