Skip to content

Refactored scanpy.tools._sparse_nanmean to eliminate unnecessary data… #3570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
acfbd11
Refactored scanpy.tools._sparse_nanmean to eliminate unnecessary data…
Reovirus Apr 8, 2025
2882948
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2025
a36b33c
rewrite logics with numba (for scipy <1.15.0)
Reovirus Apr 11, 2025
cdb443b
Add types
Reovirus Apr 11, 2025
8021f8c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
523cccc
correct jit docorators
Reovirus Apr 11, 2025
968e093
add correct fuction names
Reovirus Apr 11, 2025
9d06854
rewrite logics without prange (prange tries to rewrite one element in…
Reovirus Apr 11, 2025
4087447
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
45688e7
one ptr copy
Reovirus Apr 11, 2025
e089438
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
fa3c5c2
add score_genes benchmark
flying-sheep Apr 11, 2025
cad1db9
some changes
Reovirus Apr 11, 2025
baa0959
replace np.add.at by np.bincount + add some njint
Reovirus Apr 11, 2025
6b3e891
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
2bb6d2a
Merge branch 'main' into _sparse_nanmean_is_inefficient
flying-sheep Apr 14, 2025
ffee669
Merge branch 'scverse:main' into _sparse_nanmean_is_inefficient
Reovirus Apr 29, 2025
1c3a67e
add release notes
Reovirus Apr 30, 2025
2630ee0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2025
7130f25
style
Reovirus Apr 30, 2025
7e23b19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions benchmarks/benchmarks/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,11 @@ def time_rank_genes_groups() -> None:

def peakmem_rank_genes_groups() -> None:
sc.tl.rank_genes_groups(adata, "bulk_labels", method="wilcoxon")


def time_score_genes() -> None:
sc.tl.score_genes(adata, adata.var_names[:500])


def peakmem_score_genes() -> None:
sc.tl.score_genes(adata, adata.var_names[:500])
3 changes: 3 additions & 0 deletions docs/release-notes/3570.performance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Performance Enhancement: Optimized Mean Calculation in Gene Expression Matrix

Refactored the mean calculation logic in the gene expression matrix to eliminate unnecessary data copying. This optimization significantly improves execution speed, particularly beneficial for large-scale datasets.​
83 changes: 65 additions & 18 deletions src/scanpy/tools/_score_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

import numpy as np
import pandas as pd
from numba import prange

from .. import logging as logg
from .._compat import CSBase, old_positionals
from .._compat import CSBase, CSCBase, njit, old_positionals
from .._utils import _check_use_raw, is_backed_type
from ..get import _get_obs_rep

Expand All @@ -28,28 +29,74 @@
_GetSubset = Callable[[_StrIdx], np.ndarray | CSBase]


@njit
def _get_sparce_nanmean_indptr(
data: NDArray[np.float64], indptr: NDArray[np.int32], shape: tuple[int, int]
) -> NDArray[np.float64]:
n_rows = len(indptr) - 1
result = np.empty(n_rows, dtype=np.float64)

for i in prange(n_rows):
start = indptr[i]
end = indptr[i + 1]
count = np.float64(shape[1])
total = 0.0
for j in prange(start, end):
val = data[j]
if not np.isnan(val):
total += val
else:
count -= 1
if count == 0:
result[i] = np.nan
else:
result[i] = total / count
return result


@njit
def _get_sparce_nanmean_indices(
data: NDArray[np.float64], indices: NDArray[np.int32], shape: tuple
) -> NDArray[np.float64]:
num_bins = shape[1]
num_elements = np.float64(shape[0])
sum_arr = np.zeros(num_bins, dtype=np.float64)
count_arr = np.repeat(num_elements, num_bins)
result = np.zeros(num_bins, dtype=np.float64)

for i in range(data.size):
idx = indices[i]
val = data[i]
if not np.isnan(val):
sum_arr[idx] += val
else:
count_arr[idx] -= 1.0

for i in range(num_bins):
if count_arr[i] == 0:
result[i] = np.nan
else:
result[i] = sum_arr[i] / count_arr[i]
return result


def _sparse_nanmean(X: CSBase, axis: Literal[0, 1]) -> NDArray[np.float64]:
"""np.nanmean equivalent for sparse matrices."""
if not isinstance(X, CSBase):
msg = "X must be a compressed sparse matrix"
raise TypeError(msg)

# count the number of nan elements per row/column (dep. on axis)
Z = X.copy()
Z.data = np.isnan(Z.data)
Z.eliminate_zeros()
n_elements = Z.shape[axis] - Z.sum(axis)

# set the nans to 0, so that a normal .sum() works
Y = X.copy()
Y.data[np.isnan(Y.data)] = 0
Y.eliminate_zeros()

# the average
s = Y.sum(axis, dtype="float64") # float64 for score_genes function compatibility)
m = s / n_elements

return m
algo_shape = X.shape
algo_axis = axis
# in CSC ans CSR we have "transposed" form of data storaging (indices is colums/rows, indptr is row/columns)
# as a result, algorythm for CSC is algorythm for CSR but with transposed shape (columns in CSC is equal rows in CSR)
# base algo for CSR, for csc we should "transpose" matrix size and use same logics
if isinstance(X, CSCBase):
algo_shape = X.shape[::-1]
algo_axis = int(not axis)
if algo_axis == 1:
return _get_sparce_nanmean_indptr(X.data, X.indptr, algo_shape)
else:
return _get_sparce_nanmean_indices(X.data, X.indices, algo_shape)


@old_positionals(
Expand Down
Loading