Skip to content

Commit

Permalink
_get_mean_var updated
Browse files Browse the repository at this point in the history
  • Loading branch information
ashish615 committed Jun 18, 2024
1 parent f080c7f commit e7a4662
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 21 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/1.10.2.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@
* `pp.highly_variable_genes` with `flavor=seurat_v3` now uses a numba kernel {pr}`3017` {smaller}`S Dicks`
* Speed up {func}`~scanpy.pp.scrublet` {pr}`3044` {smaller}`S Dicks` and {pr}`3056` {smaller}`P Angerer`
* Speed up clipping of array in {func}`~scanpy.pp.scale` {pr}`3100` {smaller}`P Ashish & S Dicks`
* Speed up _get_mean_var used in {func}`~scanpy.pp.scale` {pr}`3099` {smaller}`P Ashish & S Dicks`
50 changes: 29 additions & 21 deletions src/scanpy/preprocessing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,37 +33,45 @@ def _(X: np.ndarray, *, axis: Literal[0, 1], dtype: DTypeLike) -> np.ndarray:
def _get_mean_var(
X: _SupportedArray, *, axis: Literal[0, 1] = 0
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
if isinstance(X, sparse.spmatrix):
mean, var = sparse_mean_variance_axis(X, axis=axis)
var *= X.shape[axis] / (X.shape[axis] - 1)
if isinstance(X, np.ndarray):
mean, var = _compute_mean_var(X, axis=axis, dtype=np.float64)
else:
mean,var=_compute_mean_var(X,axis=axis,dtype=np.float64)
if isinstance(X, sparse.spmatrix):
mean, var = sparse_mean_variance_axis(X, axis=axis)
else:
mean = axis_mean(X, axis=axis, dtype=np.float64)
mean_sq = axis_mean(elem_mul(X, X), axis=axis, dtype=np.float64)
var = mean_sq - mean**2

Check warning on line 44 in src/scanpy/preprocessing/_utils.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/preprocessing/_utils.py#L42-L44

Added lines #L42 - L44 were not covered by tests
# enforce R convention (unbiased estimator) for variance
var *= X.shape[axis] / (X.shape[axis] - 1)
return mean, var

@numba.njit(cache=True,parallel=True)

@numba.njit(cache=True, parallel=True)
def _compute_mean_var(
X: _SupportedArray, axis: Literal[0, 1] = 0,dtype : DTypeLike | None = None
X: _SupportedArray, axis: Literal[0, 1] = 0, dtype: DTypeLike | None = None
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
nthr = numba.get_num_threads()
axis_i = 1 if axis==0 else 0
s=np.zeros((nthr,X.shape[axis_i]),dtype=dtype)
ss=np.zeros((nthr,X.shape[axis_i]),dtype=dtype)
mean=np.zeros(X.shape[axis_i],dtype=dtype)
#std=np.zeros(X.shape[axis_i],dtype=dtype)
var=np.zeros(X.shape[axis_i],dtype=dtype)
axis_i = 1 if axis == 0 else 0
s = np.zeros((nthr, X.shape[axis_i]), dtype=dtype)
ss = np.zeros((nthr, X.shape[axis_i]), dtype=dtype)
mean = np.zeros(X.shape[axis_i], dtype=dtype)
# std=np.zeros(X.shape[axis_i],dtype=dtype)
var = np.zeros(X.shape[axis_i], dtype=dtype)
n = X.shape[axis]
for i in numba.prange(nthr):
for r in range(i,n,nthr):
for r in range(i, n, nthr):
for c in range(X.shape[axis_i]):
v = X[r,c] if axis==0 else X[c,r]
s[i,c]+=v
ss[i,c]+=v*v
v = X[r, c] if axis == 0 else X[c, r]
s[i, c] += v
ss[i, c] += v * v
for c in numba.prange(X.shape[axis_i]):
s0 = s[:,c].sum()
mean[c] = s0/n
var[c] = (ss[:,c].sum() - s0*s0/n)/(n-1)
#std[c]=np.sqrt(var[c])
return mean,var
s0 = s[:, c].sum()
mean[c] = s0 / n
var[c] = (ss[:, c].sum() - s0 * s0 / n) / (n - 1)
# std[c]=np.sqrt(var[c])
return mean, var


def sparse_mean_variance_axis(mtx: sparse.spmatrix, axis: int):
"""
Expand Down

0 comments on commit e7a4662

Please sign in to comment.