Skip to content

Commit

Permalink
Merge pull request #2 from scverse/scale-mean-variance
Browse files Browse the repository at this point in the history
Scale mean variance
  • Loading branch information
ashish615 authored Jun 26, 2024
2 parents 06c0968 + 53d40dd commit 63404f2
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions src/scanpy/preprocessing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def _get_mean_var(
X: _SupportedArray, *, axis: Literal[0, 1] = 0
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
if isinstance(X, np.ndarray):
mean, var = _compute_mean_var(X, axis=axis)
n_threads = numba.get_num_threads()
mean, var = _compute_mean_var(X, axis=axis, n_threads=n_threads)
else:
if isinstance(X, sparse.spmatrix):
mean, var = sparse_mean_variance_axis(X, axis=axis)
Expand All @@ -49,25 +50,38 @@ def _get_mean_var(

@numba.njit(cache=True, parallel=True)
def _compute_mean_var(
X: _SupportedArray, axis: Literal[0, 1] = 0
X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads=1
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
nthr = numba.get_num_threads()
axis_i = 1 if axis == 0 else 0
s = np.zeros((nthr, X.shape[axis_i]), dtype=dtype)
ss = np.zeros((nthr, X.shape[axis_i]), dtype=dtype)
mean = np.zeros(X.shape[axis_i], dtype=dtype)
var = np.zeros(X.shape[axis_i], dtype=dtype)
n = X.shape[axis]
for i in numba.prange(nthr):
for r in range(i, n, nthr):
for c in range(X.shape[axis_i]):
v = X[r, c] if axis == 0 else X[c, r]
s[i, c] += v
ss[i, c] += v * v
for c in numba.prange(X.shape[axis_i]):
s0 = s[:, c].sum()
mean[c] = s0 / n
var[c] = (ss[:, c].sum() - s0 * s0 / n) / (n - 1)
if axis == 0:
axis_i = 1
sums = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
sums_squared = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
mean = np.zeros(X.shape[axis_i], dtype=np.float64)
var = np.zeros(X.shape[axis_i], dtype=np.float64)
n = X.shape[axis]
for i in numba.prange(n_threads):
for r in range(i, n, n_threads):
for c in range(X.shape[axis_i]):
value = np.float64(X[r, c])
sums[i, c] += value
sums_squared[i, c] += value * value
for c in numba.prange(X.shape[axis_i]):
sum_ = sums[:, c].sum()
mean[c] = sum_ / n
var[c] = (sums_squared[:, c].sum() - sum_ * sum_ / n) / (n - 1)
else:
axis_i = 0
mean = np.zeros(X.shape[axis_i], dtype=np.float64)
var = np.zeros(X.shape[axis_i], dtype=np.float64)
for r in numba.prange(X.shape[0]):
for c in range(X.shape[1]):
value = np.float64(X[r, c])
mean[r] += value
var[r] += value * value
for c in numba.prange(X.shape[0]):
mean[c] = mean[c] / X.shape[1]
var[c] = (var[c] - mean[c] ** 2) / (X.shape[1] - 1)

return mean, var


Expand Down

0 comments on commit 63404f2

Please sign in to comment.