Skip to content

Commit e7a4662

Browse files
committed
_get_mean_var updated
1 parent f080c7f commit e7a4662

File tree

2 files changed

+30
-21
lines changed

2 files changed

+30
-21
lines changed

docs/release-notes/1.10.2.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@
3131
* `pp.highly_variable_genes` with `flavor=seurat_v3` now uses a numba kernel {pr}`3017` {smaller}`S Dicks`
3232
* Speed up {func}`~scanpy.pp.scrublet` {pr}`3044` {smaller}`S Dicks` and {pr}`3056` {smaller}`P Angerer`
3333
* Speed up clipping of array in {func}`~scanpy.pp.scale` {pr}`3100` {smaller}`P Ashish & S Dicks`
34+
* Speed up `_get_mean_var` used in {func}`~scanpy.pp.scale` {pr}`3099` {smaller}`P Ashish & S Dicks`

src/scanpy/preprocessing/_utils.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,37 +33,45 @@ def _(X: np.ndarray, *, axis: Literal[0, 1], dtype: DTypeLike) -> np.ndarray:
3333
def _get_mean_var(
    X: _SupportedArray, *, axis: Literal[0, 1] = 0
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Return the mean and the unbiased variance of ``X`` along ``axis``.

    Dispatches on the container type of ``X``:

    * ``np.ndarray`` -- handled by the numba kernel ``_compute_mean_var``.
    * ``scipy.sparse`` matrix -- handled by ``sparse_mean_variance_axis``.
    * anything else -- computed as ``E[X^2] - E[X]^2`` via ``axis_mean`` and
      ``elem_mul``.

    Parameters
    ----------
    X
        Two-dimensional data container.
    axis
        Axis that is reduced (the axis the observations are taken along).

    Returns
    -------
    A ``(mean, var)`` tuple with one entry per element of the other axis.
    """
    if isinstance(X, np.ndarray):
        # The dense kernel already divides by (n - 1) itself, so its result
        # must not go through the rescaling below a second time.
        return _compute_mean_var(X, axis=axis, dtype=np.float64)

    if isinstance(X, sparse.spmatrix):
        mean, var = sparse_mean_variance_axis(X, axis=axis)
    else:
        mean = axis_mean(X, axis=axis, dtype=np.float64)
        mean_sq = axis_mean(elem_mul(X, X), axis=axis, dtype=np.float64)
        var = mean_sq - mean**2

    # enforce R convention (unbiased estimator) for variance
    n_obs = X.shape[axis]
    # With a single observation the factor n / (n - 1) is 1 / 0; leave the
    # variance untouched in that case instead of raising ZeroDivisionError.
    if n_obs != 1:
        var *= n_obs / (n_obs - 1)
    return mean, var
4248

43-
@numba.njit(cache=True, parallel=True)
def _compute_mean_var(
    X: _SupportedArray, axis: Literal[0, 1] = 0, dtype: DTypeLike | None = None
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute mean and unbiased variance of a dense 2D array in one pass.

    Sums and sums of squares are accumulated in parallel and then combined
    into ``mean = s / n`` and ``var = (ss - s * s / n) / (n - 1)``.

    Parameters
    ----------
    X
        Dense array indexed as ``X[i, j]``.
    axis
        Axis that is reduced; ``n`` below is ``X.shape[axis]``.
    dtype
        dtype of the accumulators and of the returned arrays.

    Returns
    -------
    A ``(mean, var)`` tuple, each of length ``X.shape[1 - axis]``.

    Notes
    -----
    The variance uses the single-pass "sum of squares" formula, which can
    lose precision when the mean is large relative to the spread.
    """
    nthr = numba.get_num_threads()
    axis_i = 1 if axis == 0 else 0
    n = X.shape[axis]
    n_out = X.shape[axis_i]
    # One accumulator row per parallel iteration: iteration ``i`` only ever
    # writes to row ``i``, so the first loop needs no synchronisation.
    s = np.zeros((nthr, n_out), dtype=dtype)
    ss = np.zeros((nthr, n_out), dtype=dtype)
    mean = np.zeros(n_out, dtype=dtype)
    var = np.zeros(n_out, dtype=dtype)
    # With a single observation (n - 1) is zero; dividing by 1 instead yields
    # a variance of 0 rather than a division-by-zero error.
    denom = n - 1 if n > 1 else 1
    for i in numba.prange(nthr):
        # Strided split of the reduced axis: i, i + nthr, i + 2 * nthr, ...
        for r in range(i, n, nthr):
            for c in range(n_out):
                v = X[r, c] if axis == 0 else X[c, r]
                s[i, c] += v
                ss[i, c] += v * v
    for c in numba.prange(n_out):
        # Fold the per-iteration partial sums into the final statistics.
        s0 = s[:, c].sum()
        mean[c] = s0 / n
        var[c] = (ss[:, c].sum() - s0 * s0 / n) / denom
    return mean, var
74+
6775

6876
def sparse_mean_variance_axis(mtx: sparse.spmatrix, axis: int):
6977
"""

0 commit comments

Comments
 (0)