Multi-GPU support with dask #179

Open · Intron7 wants to merge 83 commits into main
Conversation

Member

@Intron7 commented Apr 25, 2024

This adds dask support.

Functions to add:

  • calculate_qc_metrics
  • normalize_total
  • log1p
  • highly_variable_genes with seurat and cell_ranger
  • scale
  • PCA
  • neighbors

@Intron7 marked this pull request as draft April 25, 2024 13:08
@Intron7 changed the title from "add first functions" to "Multi-GPU support with dask" Apr 30, 2024
@Intron7 added the run-gpu-ci label May 3, 2024
@Intron7 marked this pull request as ready for review May 13, 2024 14:27
Member

These need docstrings. What do they do?

def _cov_kernel(dtype):
    return cuda_kernel_factory(cov_kernel_str, (dtype,), "cov_kernel")

def _gramm_kernel_csr(dtype):
    return cuda_kernel_factory(gramm_kernel_csr, (dtype,), "gramm_kernel_csr")

def _copy_kernel(dtype):
    return cuda_kernel_factory(copy_kernel, (dtype,), "copy_kernel")

Member Author

They're internal functions that handle dtype for CUDA kernels. They don't need docstrings.
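
For illustration, a one-line docstring in the style being asked for might read like this (a sketch; the wording is an assumption based on the reply above, not the PR's code):

def _cov_kernel(dtype):
    """Return the covariance CUDA kernel specialized to ``dtype``."""
    return cuda_kernel_factory(cov_kernel_str, (dtype,), "cov_kernel")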

)
return gram_matrix[None, ...] # need new axis for summing

n_blocks = len(x.to_delayed().ravel())
Member

@flying-sheep commented Sep 30, 2024

Why not x.blocks.size?

Member Author

@ilan-gold why did we do this?


simply did not know this existed!
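
For context, both spellings count a dask array's blocks; a minimal NumPy-backed sketch:

import dask.array as da

x = da.ones((100, 10), chunks=(25, 10))  # 4 row blocks, 1 column block

assert len(x.to_delayed().ravel()) == 4  # builds delayed objects just to count them
assert x.blocks.size == 4                # reads the chunk metadata directly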

@Intron7 marked this pull request as ready for review October 1, 2024 08:48
Member Author

@Intron7 commented Oct 1, 2024

There will be a separate PR for updating the docstrings and a tutorial.

Comment on lines +432 to +436
adata_subset = adata[adata.obs[batch_key] == batch].copy()

calculate_qc_metrics(adata_subset, layer=layer)
filt = adata_subset.var["n_cells_by_counts"].to_numpy() > 0
- adata_subset = adata_subset[:, filt]
+ adata_subset = adata_subset[:, filt].copy()


Why copy here? Seems like there should be a more efficient way to do this.
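
One possible shape for that more efficient way (a sketch, not the PR's code; it assumes AnnData's private _inplace_subset_var, which scanpy's own filter functions use, is acceptable here):

adata_subset = adata[adata.obs[batch_key] == batch].copy()
calculate_qc_metrics(adata_subset, layer=layer)
filt = adata_subset.var["n_cells_by_counts"].to_numpy() > 0
adata_subset._inplace_subset_var(filt)  # filters columns in place, no second copy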

Comment on lines +85 to +88
if isinstance(X, sparse.csr_matrix):
    return _normalize_total_csr(X, target_sum)
elif isinstance(X, DaskArray):
    return _normalize_total_dask(X, target_sum)


@flying-sheep just so you're aware when reviewing this, this is why we can't use single dispatch anywhere: sphinx-doc/sphinx#10591


possible that this is no longer relevant, i didn't look into it too hard
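
For reference, the functools.singledispatch form that the linked sphinx issue rules out would look roughly like this (a sketch; the _normalize_total_* helpers are the ones from the diff):

from functools import singledispatch


@singledispatch
def _normalize_total(X, target_sum):
    raise NotImplementedError(f"Cannot normalize {type(X)}")


@_normalize_total.register
def _(X: sparse.csr_matrix, target_sum):
    return _normalize_total_csr(X, target_sum)


@_normalize_total.register
def _(X: DaskArray, target_sum):
    return _normalize_total_dask(X, target_sum)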



def _normalize_total(X: cp.ndarray, target_sum: int):


X type is wrong here

src/rapids_singlecell/preprocessing/_normalize.py (outdated; resolved)
    chunks=(X.chunksize[0],),
    drop_axis=1,
)
counts_per_cell = target_sum_chunk_matrices.compute()


I believe we have an implementation in scanpy that makes this lazy as well:

https://github.com/scverse/scanpy/blob/be99b230fa84e077f5167979bc9f6dacc4ad0d41/src/scanpy/preprocessing/_normalization.py#L34-L48

probably worth trying out since this is fairly expensive (i.e., requires a full pass over the data)

Member Author

I'll investigate if that's a solution for me as well.
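
A lazy variant along the lines of the linked scanpy code might look like this (a sketch; it assumes X is a CSR-backed dask array chunked along cells only, and the helper name is hypothetical):

def _counts_per_cell_lazy(X):
    # keep per-cell sums as a lazy 1D dask array instead of .compute()-ing eagerly
    return X.map_blocks(
        lambda block: cp.asarray(block.sum(axis=1)).ravel(),
        chunks=(X.chunks[0],),
        drop_axis=1,
        meta=cp.array([], dtype=X.dtype),
    )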

Comment on lines 304 to 324
blocks = X.to_delayed().ravel()
cell_blocks = [
    da.from_delayed(
        __qc_calc_1(block),
        shape=(2, X.chunks[0][ind]),
        dtype=X.dtype,
        meta=cp.array([]),
    )
    for ind, block in enumerate(blocks)
]

blocks = X.to_delayed().ravel()
gene_blocks = [
    da.from_delayed(
        __qc_calc_2(block),
        shape=(2, X.shape[1]),
        dtype=X.dtype,
        meta=cp.array([]),
    )
    for ind, block in enumerate(blocks)
]


can't we map_blocks here now that we are vstack-ing?
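
The map_blocks version being suggested could be sketched like this (assuming X is chunked along the cell axis only and __qc_calc_1 returns a (2, block_rows) array per block):

cell_results = X.map_blocks(
    __qc_calc_1,
    new_axis=0,
    drop_axis=1,
    chunks=((2,), X.chunks[0]),
    dtype=X.dtype,
    meta=cp.array([]),
)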

)


def _first_pass_qc(X):


types

Member

and needs a more descriptive name. what does it do?

)


def _second_pass_qc(X, mask):


types

Member

and a more descriptive name



@with_cupy_rmm
def _second_pass_qc_dask(X, mask):


types

Member

@flying-sheep left a comment

this needs to be heavily deduplicated, and it needs a few more comments, especially for the kernels and some very short variable names.

also you should reorganize:

  • why are _pca and _sparse_pca sibling modules? the latter should be a submodule of the former.
  • why _sparse_pca._sparse_pca? that should probably be _sparse_pca._cupy or _sparse_pca._mem or so.

src/rapids_singlecell/_compat.py (outdated; resolved)
src/rapids_singlecell/_compat.py (outdated; resolved)
Comment on lines +291 to +301
if isinstance(X, DaskArray):
    if isinstance(X._meta, cp.ndarray):
        X = X.map_blocks(lambda X: cp.expm1(X), meta=_meta_dense(X.dtype))
    elif isinstance(X._meta, csr_matrix):
        X = X.map_blocks(lambda X: X.expm1(), meta=_meta_sparse(X.dtype))
else:
    X = X.copy()
    if issparse(X):
        X = X.expm1()
    else:
        X = cp.expm1(X)
Member

should probably be wrapped into something like def expm1(X: DaskArray | csr_matrix | np.ndarray): instead of just having this inline.

also definitely needs an else branch for uncovered cases
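
A wrapper along those lines, with the missing else branches, might look like this (a sketch; _meta_dense/_meta_sparse are the helpers used in the diff):

def expm1(X: DaskArray | csr_matrix | cp.ndarray) -> DaskArray | csr_matrix | cp.ndarray:
    if isinstance(X, DaskArray):
        if isinstance(X._meta, cp.ndarray):
            return X.map_blocks(cp.expm1, meta=_meta_dense(X.dtype))
        if isinstance(X._meta, csr_matrix):
            return X.map_blocks(lambda b: b.expm1(), meta=_meta_sparse(X.dtype))
        raise TypeError(f"Unsupported dask meta type: {type(X._meta)}")
    if issparse(X):
        return X.expm1()
    if isinstance(X, cp.ndarray):
        return cp.expm1(X)
    raise TypeError(f"Unsupported type: {type(X)}")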

from ._kernels._norm_kernel import _mul_csr

mul_kernel = _mul_csr(X.dtype)
mul_kernel.compile()
Member

why do you .compile() here and not above?

Why not move these kernel .compile() calls inside? You could use functools.cache/lru_cache to reuse the compiled versions.

Member Author

CuPy takes care of caching and loading internally. It's just important that it's there; it has massive performance implications.

Member

You mean not calling it results in really bad performance, whereas calling it multiple times has no noticeable impact?

Then it should most likely go inside the wrapper (here _mul_csr) so it's impossible to forget at the call site.
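
The cached wrapper described here could be sketched as follows (mul_kernel_str is an assumed name for the kernel's template string):

from functools import lru_cache


@lru_cache
def _mul_csr(dtype):
    kernel = cuda_kernel_factory(mul_kernel_str, (dtype,), "mul_kernel")
    kernel.compile()  # compiled once per dtype; call sites can no longer forget it
    return kernel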

Comment on lines +131 to +145
elif isinstance(X._meta, cp.ndarray):
    from ._kernels._norm_kernel import _mul_dense

    mul_kernel = _mul_dense(X.dtype)
    mul_kernel.compile()

    def __mul(X_part):
        mul_kernel(
            (math.ceil(X_part.shape[0] / 128),),
            (128,),
            (X_part, X_part.shape[0], X_part.shape[1], int(target_sum)),
        )
        return X_part

    X = X.map_blocks(__mul, meta=_meta_dense(X.dtype))
Member

this branch is almost identical to the above. You should probably do:

if not isinstance(X._meta, (cp.ndarray, sparse.csr_matrix)):
    raise ValueError(f"Cannot normalize {type(X)}")
...
def mul(X_part, rename_me_1, rename_me_2): ...
...
__mul = partial(mul, rename_me_1=..., rename_me_2=...) if isinstance(X._meta, cp.ndarray) else partial(mul, rename_me_1=..., rename_me_2=...)
return X.map_blocks(...)

Member Author

I really don't see how this is supposed to work.

Member

like here: #179 (comment)

Identify the parts that differ between the two if branches and assign them to variables; pull the identical parts out of the branches, where they can use the variables you define in the branches.

If all that's left is two branches, and each branch contains nothing but variable assignments, replace the if statement with a ternary (foo, bar = (..., ...) if condition else (..., ...)).
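
Applied to the two normalize branches, the pattern might come out roughly like this (a sketch; the argument tuple for the sparse case is an assumption):

from functools import partial

if isinstance(X._meta, sparse.csr_matrix):
    kernel = _mul_csr(X.dtype)
    args_for = lambda part: (part.indptr, part.data, part.shape[0], int(target_sum))
elif isinstance(X._meta, cp.ndarray):
    kernel = _mul_dense(X.dtype)
    args_for = lambda part: (part, part.shape[0], part.shape[1], int(target_sum))
else:
    raise ValueError(f"Cannot normalize {type(X._meta)}")


def _mul(X_part, kernel, args_for):
    kernel((math.ceil(X_part.shape[0] / 128),), (128,), args_for(X_part))
    return X_part


X = X.map_blocks(partial(_mul, kernel=kernel, args_for=args_for), meta=X._meta)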

Member Author

I don't really think that this is a good idea. I know that you don't like this, but I think it is easier to maintain.

Member

looks like all these tests are copy and pasted and should be deduplicated using parametrize

Comment on lines 15 to 37
def test_normalize_sparse(client):
    adata = pbmc3k()
    sc.pp.filter_cells(adata, min_genes=100)
    sc.pp.filter_genes(adata, min_cells=3)
    dask_data = adata.copy()
    dask_data.X = as_sparse_cupy_dask_array(dask_data.X)
    adata.X = cusparse.csr_matrix(adata.X)
    rsc.pp.normalize_total(adata)
    rsc.pp.normalize_total(dask_data)
    cp.testing.assert_allclose(adata.X.toarray(), dask_data.X.compute().toarray())


def test_normalize_dense(client):
    adata = pbmc3k()
    sc.pp.filter_cells(adata, min_genes=100)
    sc.pp.filter_genes(adata, min_cells=3)
    dask_data = adata.copy()
    dask_data.X = as_dense_cupy_dask_array(dask_data.X)
    adata.X = cp.array(adata.X.toarray())
    rsc.pp.normalize_total(adata)
    rsc.pp.normalize_total(dask_data)
    cp.testing.assert_allclose(adata.X, dask_data.X.compute())

Member

deduplicate using parametrize

Comment on lines 39 to 62
def test_log1p_sparse(client):
    adata = pbmc3k()
    sc.pp.filter_cells(adata, min_genes=100)
    sc.pp.filter_genes(adata, min_cells=3)
    sc.pp.normalize_total(adata)
    dask_data = adata.copy()
    dask_data.X = as_sparse_cupy_dask_array(dask_data.X)
    adata.X = cusparse.csr_matrix(adata.X)
    rsc.pp.log1p(adata)
    rsc.pp.log1p(dask_data)
    cp.testing.assert_allclose(adata.X.toarray(), dask_data.X.compute().toarray())


def test_log1p_dense(client):
    adata = pbmc3k()
    sc.pp.filter_cells(adata, min_genes=100)
    sc.pp.filter_genes(adata, min_cells=3)
    sc.pp.normalize_total(adata)
    dask_data = adata.copy()
    dask_data.X = as_dense_cupy_dask_array(dask_data.X)
    adata.X = cp.array(adata.X.toarray())
    rsc.pp.log1p(adata)
    rsc.pp.log1p(dask_data)
    cp.testing.assert_allclose(adata.X, dask_data.X.compute())
Member

deduplicate using parametrize
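
A parametrized version of such a pair might look like this (a sketch based on the normalize tests above; the conversion helpers are the ones the tests already use):

import pytest


def _to_dense(X):
    return X.toarray() if cusparse.issparse(X) else X


@pytest.mark.parametrize(
    ("to_dask", "to_gpu"),
    [
        pytest.param(as_sparse_cupy_dask_array, cusparse.csr_matrix, id="sparse"),
        pytest.param(as_dense_cupy_dask_array, lambda X: cp.array(X.toarray()), id="dense"),
    ],
)
def test_normalize(client, to_dask, to_gpu):
    adata = pbmc3k()
    sc.pp.filter_cells(adata, min_genes=100)
    sc.pp.filter_genes(adata, min_cells=3)
    dask_data = adata.copy()
    dask_data.X = to_dask(dask_data.X)
    adata.X = to_gpu(adata.X)
    rsc.pp.normalize_total(adata)
    rsc.pp.normalize_total(dask_data)
    cp.testing.assert_allclose(_to_dense(adata.X), _to_dense(dask_data.X.compute()))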

Member

deduplicate using parametrize

Member

deduplicate using parametrize

@github-actions bot removed the run-gpu-ci label Oct 22, 2024