scale function (_get_mean_var) updated for dense array, speedup up to ~4.65x #3099

ashish615 wants to merge 14 commits into scverse:main
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```diff
@@           Coverage Diff            @@
##             main    #3099    +/-   ##
========================================
- Coverage   76.31%   76.31%   -0.01%
========================================
  Files         109      109
  Lines       12513    12516       +3
========================================
+ Hits         9549     9551       +2
- Misses       2964     2965       +1
```
Benchmark changes

Comparison: https://github.com/scverse/scanpy/compare/ad657edfb52e9957b9a93b3a16fc8a87852f3f09..e7a466265b08f6973a5cf3fecfc27879104c02f4

More details: https://github.com/scverse/scanpy/pull/3099/checks?check_run_id=26384736173
I have some small improvements that I would like to add next week, for more precision on larger matrices.
Co-authored-by: Severin Dicks <37635888+Intron7@users.noreply.github.com>
@ashish615 after doing some benchmarking myself, I found out that your solution for …
Intron7 left a comment:
Please merge IntelLabs/open-omics-scanpy#2
Scale mean variance

remove casting to match previous behavior

Co-authored-by: Severin Dicks <37635888+Intron7@users.noreply.github.com>
I don’t see the claimed speedup in the benchmarks; what’s missing?

Also, is numba.get_num_threads() safe? E.g. I think _get_mean_var is also called in each dask chunk. Will numba.get_num_threads() return a reasonable number in that case?

Otherwise nice! I’m not a huge fan of how unpythonic numba code looks, but I don’t think anything can be done about that.
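One way to probe that concern empirically (a hypothetical diagnostic, not part of this PR) is to print what numba reports inside each dask chunk:

```python
import dask.array as da
import numba
import numpy as np

def probe(block: np.ndarray) -> np.ndarray:
    # Hypothetical diagnostic: show the thread count numba sees per dask chunk.
    print("numba threads in this chunk:", numba.get_num_threads())
    return block

X = da.random.random((1_000, 100), chunks=(250, 100))
X.map_blocks(probe, dtype=X.dtype).compute()
```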
```python
# enforce R convention (unbiased estimator) for variance
var *= X.shape[axis] / (X.shape[axis] - 1)
```
Before your change, this line ran unconditionally; now it only runs in the not isinstance(X, np.ndarray) case. Is that intentional? If so, you should mention it in _compute_mean_var’s docstring.
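For context, the “R convention” here is just the unbiased (ddof=1) estimator: rescaling the population variance by n / (n - 1) is equivalent to dividing the sum of squared deviations by n - 1, as this small check illustrates:

```python
import numpy as np

x = np.random.rand(100)
n = x.size
# Population variance (ddof=0) rescaled by n / (n - 1) equals ddof=1 variance.
np.testing.assert_allclose(x.var() * n / (n - 1), x.var(ddof=1))
```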
```python
@numba.njit(cache=True, parallel=True)
def _compute_mean_var(
```
We already have _get_mean_var. Maybe rename this to _get_mean_var_ndarray or _get_mean_var_dense?
I think we can rename the kernel.
```python
@numba.njit(cache=True, parallel=True)
def _compute_mean_var(
    X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads=1
```
Suggested change:

```diff
-    X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads=1
+    X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads: int = 1
```
```python
    if axis == 0:
        axis_i = 1
        sums = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
        sums_squared = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
        mean = np.zeros(X.shape[axis_i], dtype=np.float64)
        var = np.zeros(X.shape[axis_i], dtype=np.float64)
        n = X.shape[axis]
        for i in numba.prange(n_threads):
            for r in range(i, n, n_threads):
                for c in range(X.shape[axis_i]):
                    value = X[r, c]
                    sums[i, c] += value
                    sums_squared[i, c] += value * value
        for c in numba.prange(X.shape[axis_i]):
            sum_ = sums[:, c].sum()
            mean[c] = sum_ / n
            var[c] = (sums_squared[:, c].sum() - sum_ * sum_ / n) / (n - 1)
    else:
        axis_i = 0
        mean = np.zeros(X.shape[axis_i], dtype=np.float64)
        var = np.zeros(X.shape[axis_i], dtype=np.float64)
        for r in numba.prange(X.shape[0]):
            for c in range(X.shape[1]):
                value = X[r, c]
                mean[r] += value
                var[r] += value * value
        for c in numba.prange(X.shape[0]):
            mean[c] = mean[c] / X.shape[1]
            var[c] = (var[c] - mean[c] ** 2) / (X.shape[1] - 1)
```
Please don’t duplicate identical lines.
Suggested change:

```diff
-    if axis == 0:
-        axis_i = 1
-        sums = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
-        sums_squared = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
-        mean = np.zeros(X.shape[axis_i], dtype=np.float64)
-        var = np.zeros(X.shape[axis_i], dtype=np.float64)
-        n = X.shape[axis]
-        for i in numba.prange(n_threads):
-            for r in range(i, n, n_threads):
-                for c in range(X.shape[axis_i]):
-                    value = X[r, c]
-                    sums[i, c] += value
-                    sums_squared[i, c] += value * value
-        for c in numba.prange(X.shape[axis_i]):
-            sum_ = sums[:, c].sum()
-            mean[c] = sum_ / n
-            var[c] = (sums_squared[:, c].sum() - sum_ * sum_ / n) / (n - 1)
-    else:
-        axis_i = 0
-        mean = np.zeros(X.shape[axis_i], dtype=np.float64)
-        var = np.zeros(X.shape[axis_i], dtype=np.float64)
-        for r in numba.prange(X.shape[0]):
-            for c in range(X.shape[1]):
-                value = X[r, c]
-                mean[r] += value
-                var[r] += value * value
-        for c in numba.prange(X.shape[0]):
-            mean[c] = mean[c] / X.shape[1]
-            var[c] = (var[c] - mean[c] ** 2) / (X.shape[1] - 1)
+    axis_i = 1 - axis
+    mean = np.zeros(X.shape[axis_i], dtype=np.float64)
+    var = np.zeros(X.shape[axis_i], dtype=np.float64)
+    if axis == 0:
+        sums = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
+        sums_squared = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
+        n = X.shape[axis]
+        for i in numba.prange(n_threads):
+            for r in range(i, n, n_threads):
+                for c in range(X.shape[axis_i]):
+                    value = X[r, c]
+                    sums[i, c] += value
+                    sums_squared[i, c] += value * value
+        for c in numba.prange(X.shape[axis_i]):
+            sum_ = sums[:, c].sum()
+            mean[c] = sum_ / n
+            var[c] = (sums_squared[:, c].sum() - sum_ * sum_ / n) / (n - 1)
+    else:
+        for r in numba.prange(X.shape[0]):
+            for c in range(X.shape[1]):
+                value = X[r, c]
+                mean[r] += value
+                var[r] += value * value
+        for c in numba.prange(X.shape[0]):
+            mean[c] = mean[c] / X.shape[1]
+            var[c] = (var[c] - mean[c] ** 2) / (X.shape[1] - 1)
```
I think we can slim this down a bit; the two different loops need to be separate, though.
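A minimal sanity check one could run against the axis=0 path (the import location below is an assumption; the function is whatever this branch defines): it verifies the kernel's sum/sum-of-squares formulation against numpy's unbiased statistics.

```python
import numpy as np

from scanpy.preprocessing._utils import _compute_mean_var  # assumed location

X = np.random.rand(1_000, 50)
mean, var = _compute_mean_var(X, axis=0, n_threads=4)

# The kernel divides by (n - 1), so it should match ddof=1 exactly.
np.testing.assert_allclose(mean, X.mean(axis=0))
np.testing.assert_allclose(var, X.var(axis=0, ddof=1))
```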
The function should also work for 1 thread. numba.get_num_threads() is fine; it works well with the sparse arrays. But I have no experience with it inside of dask.
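As a sketch of how the dense call site might supply the thread count (the wrapper name and dispatch below are assumptions, not code from this PR):

```python
import numba
import numpy as np

def _get_mean_var_dense(X: np.ndarray, axis: int = 0):
    # Hypothetical wrapper: let numba report the available worker count.
    # With n_threads == 1 the striped row loop still visits every row.
    return _compute_mean_var(X, axis=axis, n_threads=numba.get_num_threads())
```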
```python
@numba.njit(cache=True, parallel=True)
def _compute_mean_var(
    X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads=1
```
I don't think _SupportedArray is the right type annotation here. This doesn't run directly on dask.Array, unless I am misunderstanding something.
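Something like this narrower stub might fit better (a sketch, assuming the kernel only ever receives dense numpy arrays):

```python
from typing import Literal

import numpy as np
from numpy.typing import NDArray

def _compute_mean_var(
    X: NDArray[np.float64], axis: Literal[0, 1] = 0, n_threads: int = 1
) -> tuple[NDArray[np.float64], NDArray[np.float64]]: ...
```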
Hi,

We are submitting a PR to speed up the _get_mean_var function.

Experiment setup: AWS r7i.24xlarge

We added a timer around the _get_mean_var call:

scanpy/scanpy/preprocessing/_scale.py (Line 167 in 706d4ef)

We could also create a _get_mean_var_std function that returns the std as well, so we don't need to compute it in the scale function (lines 168-169); a sketch follows below.
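A hypothetical sketch of that helper (the name, signature, zero-std clamp, and import location are assumptions mirroring what scale() is described as doing, not code from this PR):

```python
import numpy as np

from scanpy.preprocessing._utils import _get_mean_var  # assumed location

def _get_mean_var_std(X, *, axis: int = 0):
    # Hypothetical helper: return std alongside mean/var so scale() need not
    # recompute it. The clamp avoids dividing by zero for constant features.
    mean, var = _get_mean_var(X, axis=axis)
    std = np.sqrt(var)
    std[std == 0] = 1
    return mean, var, std
```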