Multi-GPU support with dask #179
base: main
Conversation
for more information, see https://pre-commit.ci
These need docstrings. What do they do?
rapids_singlecell/src/rapids_singlecell/preprocessing/_sparse_pca/_kernels/_pca_sparse_kernel.py
Lines 68 to 77 in b6c2689
def _cov_kernel(dtype):
    return cuda_kernel_factory(cov_kernel_str, (dtype,), "cov_kernel")

def _gramm_kernel_csr(dtype):
    return cuda_kernel_factory(gramm_kernel_csr, (dtype,), "gramm_kernel_csr")

def _copy_kernel(dtype):
    return cuda_kernel_factory(copy_kernel, (dtype,), "copy_kernel")
They are internal functions that handle dtypes for CUDA kernels. They don't need docstrings.
)
return gram_matrix[None, ...]  # need new axis for summing

n_blocks = len(x.to_delayed().ravel())
Why not x.blocks.size?
@ilan-gold why did we do this?
simply did not know this existed!
There will be a separate PR for the update of the docstrings and a tutorial.
adata_subset = adata[adata.obs[batch_key] == batch].copy()

calculate_qc_metrics(adata_subset, layer=layer)
filt = adata_subset.var["n_cells_by_counts"].to_numpy() > 0
adata_subset = adata_subset[:, filt]
adata_subset = adata_subset[:, filt].copy()
Why copy here? Seems like there should be a more efficient way to do this.
if isinstance(X, sparse.csr_matrix):
    return _normalize_total_csr(X, target_sum)
elif isinstance(X, DaskArray):
    return _normalize_total_dask(X, target_sum)
@flying-sheep just so you're aware when reviewing this, this is why we can't use single dispatch anywhere: sphinx-doc/sphinx#10591
Possible that this is no longer relevant; I didn't look into it too hard.
def _normalize_total(X: cp.ndarray, target_sum: int):
The X type annotation is wrong here.
    chunks=(X.chunksize[0],),
    drop_axis=1,
)
counts_per_cell = target_sum_chunk_matrices.compute()
I believe we have an implementation in scanpy that makes this lazy as well: probably worth trying out, since this is fairly expensive (i.e., it requires a full pass over the data).
I'll investigate whether that's a solution for me as well.
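A minimal sketch of what a lazy variant might look like, using dask's own reduction so the counts are only materialized together with the final result (NumPy stand-in for the CuPy arrays; this is not the actual scanpy implementation):

```python
import numpy as np
import dask.array as da

# 4x3 array, two row-chunks; values 1..12 so no row sums to zero
X = da.from_array(np.arange(1, 13, dtype=np.float64).reshape(4, 3), chunks=(2, 3))
counts_per_cell = X.sum(axis=1)        # stays lazy, no eager .compute()
normed = X / counts_per_cell[:, None]  # division fuses into the same graph
result = normed.compute()              # a single pass at the end
```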
blocks = X.to_delayed().ravel()
cell_blocks = [
    da.from_delayed(
        __qc_calc_1(block),
        shape=(2, X.chunks[0][ind]),
        dtype=X.dtype,
        meta=cp.array([]),
    )
    for ind, block in enumerate(blocks)
]

blocks = X.to_delayed().ravel()
gene_blocks = [
    da.from_delayed(
        __qc_calc_2(block),
        shape=(2, X.shape[1]),
        dtype=X.dtype,
        meta=cp.array([]),
    )
    for ind, block in enumerate(blocks)
]
Can't we use map_blocks here now that we are vstack-ing?
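A sketch of the map_blocks shape, shown for the per-cell counts with a NumPy-backed array (the PR's blocks are CuPy/CSR; names here are illustrative):

```python
import numpy as np
import dask.array as da

X = da.from_array(np.ones((6, 4)), chunks=(3, 4))

def _counts_per_cell(block):
    # per-block reduction over genes; one value per cell in the block
    return block.sum(axis=1)

# drop_axis=1: each (3, 4) block maps to a (3,) block of per-cell counts,
# replacing the hand-built da.from_delayed list
counts = X.map_blocks(_counts_per_cell, drop_axis=1, dtype=X.dtype)
```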
)

def _first_pass_qc(X):
types
And it needs a more descriptive name. What does it do?
)

def _second_pass_qc(X, mask):
types
and a more descriptive name
@with_cupy_rmm
def _second_pass_qc_dask(X, mask):
types
This needs to be heavily deduplicated and could use a few more comments, especially the kernels and some very short variable names.
Also, you should reorganize:
- why are _pca and _sparse_pca sibling modules? the latter should be a submodule of the former.
- why _sparse_pca._sparse_pca? that should probably be _sparse_pca._cupy or _sparse_pca._mem or so.
if isinstance(X, DaskArray):
    if isinstance(X._meta, cp.ndarray):
        X = X.map_blocks(lambda X: cp.expm1(X), meta=_meta_dense(X.dtype))
    elif isinstance(X._meta, csr_matrix):
        X = X.map_blocks(lambda X: X.expm1(), meta=_meta_sparse(X.dtype))
else:
    X = cp.expm1(X)
X = X.copy()
if issparse(X):
    X = X.expm1()
else:
    X = cp.expm1(X)
This should probably be wrapped into something like def expm1(X: DaskArray | csr_matrix | np.ndarray): instead of just having it inline.
It also definitely needs an else branch for uncovered cases.
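A sketch of that wrapper with the else branch, using NumPy/SciPy types in place of the CuPy/dask ones from the PR:

```python
import numpy as np
from scipy import sparse

def expm1(X):
    # dispatch on type once, and fail loudly on anything unexpected
    if sparse.issparse(X):
        return X.expm1()
    elif isinstance(X, np.ndarray):
        return np.expm1(X)
    else:
        raise TypeError(f"expm1 is not implemented for {type(X)}")
```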
from ._kernels._norm_kernel import _mul_csr

mul_kernel = _mul_csr(X.dtype)
mul_kernel.compile()
Why do you .compile() here and not above? Why not make these kernel .compile() calls happen inside? You could use functools.cache/lru_cache to reuse the compiled versions.
CuPy takes care of caching and loading internally. It's just important that it's there; it has massive performance implications.
You mean not calling it results in really bad performance, whereas calling it multiple times has no noticeable impact? Then it should most likely go inside of the wrapper (here _mul_csr) so it's impossible to forget at the call site.
elif isinstance(X._meta, cp.ndarray):
    from ._kernels._norm_kernel import _mul_dense

    mul_kernel = _mul_dense(X.dtype)
    mul_kernel.compile()

    def __mul(X_part):
        mul_kernel(
            (math.ceil(X_part.shape[0] / 128),),
            (128,),
            (X_part, X_part.shape[0], X_part.shape[1], int(target_sum)),
        )
        return X_part

    X = X.map_blocks(__mul, meta=_meta_dense(X.dtype))
this branch is almost identical to the above. You should probably do:
if not isinstance(X._meta, (cp.ndarray, sparse.csr_matrix)):
raise ValueError(f"Cannot normalize {type(X)}")
...
def mul(X_part, rename_me_1, rename_me_2): ...
...
__mul = partial(mul, rename_me_1=..., rename_me_2=...) if isinstance(X._meta, cp.ndarray) else partial(mul, rename_me_1=..., rename_me_2=...)
return X.map_blocks(...)
I really don't see how this is supposed to work.
Like here: #179 (comment)
Identify the parts that are different in both if branches, assign them to variables, and pull the parts that are identical out of the branches, where they can use the variables you define in the branches.
If all that's left is two branches, and each branch contains nothing but variable assignments, replace the if statement with a ternary (foo, bar = (..., ...) if condition else (..., ...)).
I don't really think that this is a good idea. I know that you don't like this, but I think this is easier to maintain.
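For what it's worth, the shape the reviewer describes can be sketched abstractly; the names and stand-in values here are illustrative, not the PR's actual kernels:

```python
from functools import partial

def _mul(block, *, kernel_name, launch_cols):
    # shared launch body; the branch-specific bits arrive as keyword args
    return (kernel_name, launch_cols)

def normalize_block(block, meta_is_dense):
    # shared validation pulled out of the branches
    if not isinstance(block, list):
        raise TypeError(f"Cannot normalize {type(block)}")
    # only the differing pieces live in the ternary
    mul = (
        partial(_mul, kernel_name="mul_dense", launch_cols=len(block))
        if meta_is_dense
        else partial(_mul, kernel_name="mul_csr", launch_cols=None)
    )
    return mul(block)  # shared call site, written once
```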
tests/dask/test_hvg_dask.py
Outdated
looks like all these tests are copy and pasted and should be deduplicated using parametrize
tests/dask/test_normalize_dask.py
Outdated
def test_normalize_sparse(client):
    adata = pbmc3k()
    sc.pp.filter_cells(adata, min_genes=100)
    sc.pp.filter_genes(adata, min_cells=3)
    dask_data = adata.copy()
    dask_data.X = as_sparse_cupy_dask_array(dask_data.X)
    adata.X = cusparse.csr_matrix(adata.X)
    rsc.pp.normalize_total(adata)
    rsc.pp.normalize_total(dask_data)
    cp.testing.assert_allclose(adata.X.toarray(), dask_data.X.compute().toarray())

def test_normalize_dense(client):
    adata = pbmc3k()
    sc.pp.filter_cells(adata, min_genes=100)
    sc.pp.filter_genes(adata, min_cells=3)
    dask_data = adata.copy()
    dask_data.X = as_dense_cupy_dask_array(dask_data.X)
    adata.X = cp.array(adata.X.toarray())
    rsc.pp.normalize_total(adata)
    rsc.pp.normalize_total(dask_data)
    cp.testing.assert_allclose(adata.X, dask_data.X.compute())
deduplicate using parametrize
tests/dask/test_normalize_dask.py
Outdated
def test_log1p_sparse(client):
    adata = pbmc3k()
    sc.pp.filter_cells(adata, min_genes=100)
    sc.pp.filter_genes(adata, min_cells=3)
    sc.pp.normalize_total(adata)
    dask_data = adata.copy()
    dask_data.X = as_sparse_cupy_dask_array(dask_data.X)
    adata.X = cusparse.csr_matrix(adata.X)
    rsc.pp.log1p(adata)
    rsc.pp.log1p(dask_data)
    cp.testing.assert_allclose(adata.X.toarray(), dask_data.X.compute().toarray())

def test_log1p_dense(client):
    adata = pbmc3k()
    sc.pp.filter_cells(adata, min_genes=100)
    sc.pp.filter_genes(adata, min_cells=3)
    sc.pp.normalize_total(adata)
    dask_data = adata.copy()
    dask_data.X = as_dense_cupy_dask_array(dask_data.X)
    adata.X = cp.array(adata.X.toarray())
    rsc.pp.log1p(adata)
    rsc.pp.log1p(dask_data)
    cp.testing.assert_allclose(adata.X, dask_data.X.compute())
deduplicate using parametrize
tests/dask/test_qc_dask.py
Outdated
deduplicate using parametrize
tests/dask/test_scale_dask.py
Outdated
deduplicate using parametrize
Co-authored-by: Philipp A. <[email protected]>
This adds dask support.
Functions to add:
- calculate_qc_metrics
- normalize_total
- log1p
- highly_variable_genes with seurat and cell_ranger
- scale
- PCA
- neighbors