Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Intron7 committed May 3, 2024
1 parent 177afa1 commit dd1377c
Show file tree
Hide file tree
Showing 10 changed files with 432 additions and 25 deletions.
9 changes: 4 additions & 5 deletions src/rapids_singlecell/preprocessing/_hvg.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,13 +291,12 @@ def _highly_variable_genes_single_batch(
X = _get_obs_rep(adata, layer=layer)
_check_gpu_X(X, allow_dask=True)
if hasattr(X, "_view_args"): # AnnData array view
# For compatibility with anndata<0.9
X = X.copy() # Doesn't actually copy memory, just removes View class wrapper
X = X.copy()

if flavor == "seurat":
if isinstance(X, DaskArray):
if isinstance(X._meta, cp.ndarray):
X = X.map_blocks(cp.expm1, meta=_meta_dense(X.dtype))
X = X.map_blocks(lambda X: cp.expm1(X), meta=_meta_dense(X.dtype))
elif isinstance(X._meta, csr_matrix):
X = X.map_blocks(lambda X: X.expm1(), meta=_meta_sparse(X.dtype))
else:
Expand Down Expand Up @@ -433,11 +432,11 @@ def _highly_variable_genes_batched(
dfs = []
gene_list = adata.var_names
for batch in batches:
adata_subset = adata[adata.obs[batch_key] == batch]
adata_subset = adata[adata.obs[batch_key] == batch].copy()

calculate_qc_metrics(adata_subset, layer=layer, client=client)
filt = adata_subset.var["n_cells_by_counts"].to_numpy() > 0
adata_subset = adata_subset[:, filt]
adata_subset = adata_subset[:, filt].copy()

hvg = _highly_variable_genes_single_batch(
adata_subset,
Expand Down
4 changes: 2 additions & 2 deletions src/rapids_singlecell/preprocessing/_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __mul(X_part):
return X_part

X = X.map_blocks(lambda X: __mul(X), meta=_meta_sparse(X.dtype))
elif isinstance(X.meta, cp.ndarray):
elif isinstance(X._meta, cp.ndarray):
from ._kernels._norm_kernel import _mul_dense

mul_kernel = _mul_dense(X.dtype)
Expand Down Expand Up @@ -269,7 +269,7 @@ def log1p(
X = X.log1p()
elif isinstance(X, DaskArray):
if isinstance(X._meta, cp.ndarray):
X = X.map_blocks(cp.log1p, meta=_meta_dense(X.dtype))
X = X.map_blocks(lambda X: cp.log1p(X), meta=_meta_dense(X.dtype))
elif isinstance(X._meta, sparse.csr_matrix):
X = X.map_blocks(lambda X: X.log1p(), meta=_meta_sparse(X.dtype))
adata.uns["log1p"] = {"base": None}
Expand Down
23 changes: 17 additions & 6 deletions src/rapids_singlecell/preprocessing/_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,17 @@ def pca(

if svd_solver == "auto":
svd_solver = "jacobi"
pca_func = PCA(n_components=n_comps, svd_solver=svd_solver, client=client)
pca_func = PCA(
n_components=n_comps, svd_solver=svd_solver, whiten=False, client=client
)
X_pca = pca_func.fit_transform(X)
X_pca = X_pca.compute_chunk_sizes()
elif isinstance(X._meta, csr_matrix):
from ._sparse_pca._dask_sparse_pca import PCA_sparse_dask

pca_func = PCA_sparse_dask(n_components=n_comps, client=client)
X_pca = pca_func.fit_transform(X)
pca_func = pca_func.fit(X)
X_pca = pca_func.transform(X)

else:
if chunked:
Expand Down Expand Up @@ -213,16 +217,23 @@ def pca(
"use_highly_variable": mask_var_param == "highly_variable",
"mask_var": mask_var_param,
},
"variance": pca_func.explained_variance_,
"variance_ratio": pca_func.explained_variance_ratio_,
"variance": _as_numpy(pca_func.explained_variance_),
"variance_ratio": _as_numpy(pca_func.explained_variance_ratio_),
}
adata.uns["pca"] = uns_entry
if layer is not None:
adata.uns["pca"]["params"]["layer"] = layer
if mask_var is not None:
adata.varm["PCs"] = np.zeros(shape=(adata.n_vars, n_comps))
adata.varm["PCs"][mask_var] = pca_func.components_.T
adata.varm["PCs"][mask_var] = _as_numpy(pca_func.components_.T)
else:
adata.varm["PCs"] = pca_func.components_.T
adata.varm["PCs"] = _as_numpy(pca_func.components_.T)
if copy:
return adata


def _as_numpy(X):
    """Return *X* as a host (NumPy) array: CuPy arrays are copied to host, everything else passes through."""
    return X.get() if isinstance(X, cp.ndarray) else X
12 changes: 5 additions & 7 deletions src/rapids_singlecell/preprocessing/_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ def __qc_calc(X_part):
elif isinstance(X._meta, cp.ndarray):
from ._kernels._qc_kernels import _sparse_qc_dense

_sparse_qc_dense = _sparse_qc_dense(X.dtype)
_sparse_qc_dense.compile()
sparse_qc_dense = _sparse_qc_dense(X.dtype)
sparse_qc_dense.compile()

def __qc_calc(X_part):
sums_cells = cp.zeros(X_part.shape[0], dtype=X_part.dtype)
Expand All @@ -238,7 +238,6 @@ def __qc_calc(X_part):
int(math.ceil(X_part.shape[0] / block[0])),
int(math.ceil(X_part.shape[1] / block[1])),
)
sparse_qc_dense = _sparse_qc_dense(X.dtype)
sparse_qc_dense(
grid,
block,
Expand Down Expand Up @@ -364,10 +363,10 @@ def __qc_calc(X_part):
return sums_cells_sub

elif isinstance(X._meta, cp.ndarray):
from ._kernels._qc_kernels import _sparse_qc_dense
from ._kernels._qc_kernels import _sparse_qc_dense_sub

_sparse_qc_dense = _sparse_qc_dense(X.dtype)
_sparse_qc_dense.compile()
sparse_qc_dense = _sparse_qc_dense_sub(X.dtype)
sparse_qc_dense.compile()

def __qc_calc(X_part):
sums_cells_sub = cp.zeros((X_part.shape[0]), dtype=X_part.dtype)
Expand All @@ -378,7 +377,6 @@ def __qc_calc(X_part):
int(math.ceil(X_part.shape[0] / block[0])),
int(math.ceil(X_part.shape[1] / block[1])),
)
sparse_qc_dense = _sparse_qc_dense(X.dtype)
sparse_qc_dense(
grid,
block,
Expand Down
64 changes: 59 additions & 5 deletions src/rapids_singlecell/preprocessing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,53 @@ def __mean_var(X_part, minor, major):
return mean, var


@with_cupy_rmm
def _mean_var_dense_dask(X, axis, client=None):
    """
    Compute the per-axis mean and variance of a dask array backed by
    dense CuPy blocks, by gathering partial sums from each partition.
    Returns ``(mean, var)`` with the sample variance (ddof=1).
    """
    import dask.array as da

    client = _get_dask_client(client)

    def __mean_var(X_part, axis):
        # Partial reductions for a single block: sum and sum of squares.
        mean = X_part.sum(axis=axis)
        var = (X_part**2).sum(axis=axis)
        if axis == 0:
            # Column-wise partials are reshaped to (n_cols, 1) so the
            # per-block results can be concatenated along axis 1 below.
            mean = mean.reshape(-1, 1)
            var = var.reshape(-1, 1)
        return mean, var

    parts = client.sync(_extract_partitions, X)
    # Pin each partial reduction to the worker that holds the partition.
    futures = [client.submit(__mean_var, part, axis, workers=[w]) for w, part in parts]
    # Gather results from futures
    results = client.gather(futures)

    # Initialize lists to hold the Dask arrays
    means_objs = []
    var_objs = []

    # Process each result
    for means, vars_ in results:
        # Append the arrays to their respective lists as Dask arrays
        means_objs.append(da.from_array(means, chunks=means.shape))
        var_objs.append(da.from_array(vars_, chunks=vars_.shape))
    if axis == 0:
        mean = da.concatenate(means_objs, axis=1).sum(axis=1)
        var = da.concatenate(var_objs, axis=1).sum(axis=1)
    else:
        mean = da.concatenate(means_objs)
        var = da.concatenate(var_objs)

    mean, var = da.compute(mean, var)
    mean, var = mean.ravel(), var.ravel()
    mean = mean / X.shape[axis]
    var = var / X.shape[axis]
    # var = E[x^2] - E[x]^2, followed by Bessel's correction (ddof=1).
    var -= cp.power(mean, 2)
    var *= X.shape[axis] / (X.shape[axis] - 1)
    return mean, var


def _get_mean_var(X, axis=0, client=None):
if issparse(X):
if axis == 0:
Expand Down Expand Up @@ -191,7 +238,8 @@ def _get_mean_var(X, axis=0, client=None):
major = X.shape[0]
minor = X.shape[1]
mean, var = _mean_var_major_dask(X, major, minor, client)

elif isinstance(X._meta, cp.ndarray):
mean, var = _mean_var_dense_dask(X, axis, client)
else:
mean = X.mean(axis=axis)
var = X.var(axis=axis)
Expand All @@ -215,7 +263,16 @@ def _check_nonnegative_integers(X):


def _check_gpu_X(X, require_cf=False, allow_dask=False):
if isinstance(X, cp.ndarray):
if isinstance(X, DaskArray):
if allow_dask:
return _check_gpu_X(X._meta)
else:
raise TypeError(
"The input is a DaskArray. "
"Rapids-singlecell doesn't support DaskArray in this function, "
"so your input must be a CuPy ndarray or a CuPy sparse matrix. "
)
elif isinstance(X, cp.ndarray):
return True
elif issparse(X):
if not require_cf:
Expand All @@ -225,9 +282,6 @@ def _check_gpu_X(X, require_cf=False, allow_dask=False):
else:
X.sort_indices()
X.sum_duplicates()
elif allow_dask:
if isinstance(X, DaskArray):
return _check_gpu_X(X._meta)
else:
raise TypeError(
"The input is not a CuPy ndarray or CuPy sparse matrix. "
Expand Down
48 changes: 48 additions & 0 deletions tests/dask/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations

import cupy as cp
from cupyx.scipy import sparse as cusparse
from anndata.tests.helpers import as_dense_dask_array, as_sparse_dask_array
import pytest

from dask_cuda import LocalCUDACluster
from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
from dask.distributed import Client


def as_sparse_cupy_dask_array(X):
    """Convert *X* to a dask array of GPU CSR blocks, rechunked into two row chunks.

    Fixes: PEP 8 keyword-argument spacing (``dtype=X.dtype``) and renames the
    local ``da``, which shadows the conventional ``dask.array`` module alias.
    """
    arr = as_sparse_dask_array(X)
    arr = arr.rechunk((arr.shape[0] // 2, arr.shape[1]))
    return arr.map_blocks(cusparse.csr_matrix, dtype=X.dtype)

def as_dense_cupy_dask_array(X):
    """Convert *X* to a dask array of dense CuPy blocks, rechunked into two row chunks."""
    dask_arr = as_dense_dask_array(X)
    dask_arr = dask_arr.map_blocks(cp.array)
    return dask_arr.rechunk((dask_arr.shape[0] // 2, dask_arr.shape[1]))

from dask_cuda import initialize
from dask_cuda import LocalCUDACluster
from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
from dask.distributed import Client

@pytest.fixture(scope="module")
def cluster():
    """Module-scoped single-GPU dask-CUDA cluster shared across the tests.

    Fixes: PEP 8 spacing on ``CUDA_VISIBLE_DEVICES="0"`` and the stray blank
    line at the top of the fixture body.
    """
    cluster = LocalCUDACluster(
        CUDA_VISIBLE_DEVICES="0",
        protocol="tcp",
        scheduler_port=0,  # let the OS pick a free port
        worker_class=IncreasedCloseTimeoutNanny,
    )
    yield cluster
    cluster.close()


@pytest.fixture(scope="function")
def client(cluster):
    """Fresh dask client per test, connected to the shared module cluster."""
    # Client's context manager closes the client on exit, matching the
    # explicit close() of the original.
    with Client(cluster) as dask_client:
        yield dask_client
62 changes: 62 additions & 0 deletions tests/dask/test_dask_pca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

import cupy as cp
import numpy as np
from cupyx.scipy import sparse as cusparse
from scipy import sparse
from conftest import as_dense_cupy_dask_array, as_sparse_cupy_dask_array
import rapids_singlecell as rsc

from scanpy.datasets import pbmc3k_processed

def test_pca_sparse_dask(client):
    """PCA on a dask array of sparse GPU blocks must match the non-dask result."""
    ref_ad = pbmc3k_processed()
    dask_ad = pbmc3k_processed()
    ref_ad.X = sparse.csr_matrix(ref_ad.X.astype(np.float64))
    dask_ad.X = as_sparse_cupy_dask_array(dask_ad.X.astype(np.float64))
    rsc.pp.pca(ref_ad)
    rsc.pp.pca(dask_ad)

    # Principal components are sign-ambiguous, so compare absolute values.
    cp.testing.assert_allclose(
        np.abs(ref_ad.obsm["X_pca"]),
        cp.abs(dask_ad.obsm["X_pca"].compute()),
        rtol=1e-7,
        atol=1e-6,
    )

    cp.testing.assert_allclose(
        np.abs(ref_ad.varm["PCs"]),
        np.abs(dask_ad.varm["PCs"]),
        rtol=1e-7,
        atol=1e-6,
    )

    cp.testing.assert_allclose(
        np.abs(ref_ad.uns["pca"]["variance_ratio"]),
        np.abs(dask_ad.uns["pca"]["variance_ratio"]),
        rtol=1e-7,
        atol=1e-6,
    )

def test_pca_dense_dask(client):
    """PCA on a dask array of dense CuPy blocks must match the single-GPU result.

    Fixes: the reference AnnData was named ``sparse_ad`` although its ``X`` is
    a dense CuPy array — renamed to ``dense_ad`` (and ``default`` to
    ``dask_ad``) so the names describe the data they hold.
    """
    dense_ad = pbmc3k_processed()
    dask_ad = pbmc3k_processed()
    dense_ad.X = cp.array(dense_ad.X.astype(np.float64))
    dask_ad.X = as_dense_cupy_dask_array(dask_ad.X.astype(np.float64))
    rsc.pp.pca(dense_ad, svd_solver="full")
    rsc.pp.pca(dask_ad, svd_solver="full")

    # Principal components are sign-ambiguous, so compare absolute values.
    cp.testing.assert_allclose(
        np.abs(dense_ad.obsm["X_pca"]),
        cp.abs(dask_ad.obsm["X_pca"].compute()),
        rtol=1e-7,
        atol=1e-6,
    )

    cp.testing.assert_allclose(
        np.abs(dense_ad.varm["PCs"]),
        np.abs(dask_ad.varm["PCs"]),
        rtol=1e-7,
        atol=1e-6,
    )

    cp.testing.assert_allclose(
        np.abs(dense_ad.uns["pca"]["variance_ratio"]),
        np.abs(dask_ad.uns["pca"]["variance_ratio"]),
        rtol=1e-7,
        atol=1e-6,
    )
Loading

0 comments on commit dd1377c

Please sign in to comment.