scale function(_get_mean_var) updated for dense array, speedup upto ~4.65x #3099
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #3099      +/-   ##
==========================================
- Coverage   76.31%   76.31%   -0.01%
==========================================
  Files         109      109
  Lines       12513    12516       +3
==========================================
+ Hits         9549     9551       +2
- Misses       2964     2965       +1
```
Benchmark changes
Comparison: https://github.com/scverse/scanpy/compare/ad657edfb52e9957b9a93b3a16fc8a87852f3f09..e7a466265b08f6973a5cf3fecfc27879104c02f4
More details: https://github.com/scverse/scanpy/pull/3099/checks?check_run_id=26384736173
I have some small improvements I would like to add next week for more precision on larger matrices.
Co-authored-by: Severin Dicks <37635888+Intron7@users.noreply.github.com>
@ashish615 after doing some benchmarking myself I found out that your solution for …
Please merge IntelLabs#2
Scale mean variance
This fixes the issues.
remove casting to match previous behavior
Co-authored-by: Severin Dicks <37635888+Intron7@users.noreply.github.com>
Looks good to me
I don’t see the claimed speedup in the benchmarks; what’s missing?

Also, is `numba.get_num_threads()` safe? E.g. I think `_get_mean_var` is also called in each dask chunk. Will `numba.get_num_threads()` return a reasonable number in that case?

Otherwise nice! I’m not a huge fan of how unpythonic numba code looks, but I don’t think anything can be done about that.
```python
# enforce R convention (unbiased estimator) for variance
var *= X.shape[axis] / (X.shape[axis] - 1)
```
Before your change, this line ran unconditionally; now it only runs in the `not isinstance(X, np.ndarray)` case. Is that intentional? If so, you should mention that in `_compute_mean_var`’s docstring.
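The line under discussion converts the biased (population) variance into the unbiased (`ddof=1`) sample estimator. A quick numpy check of that identity:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))
n = X.shape[0]

# Biased (population) variance, ddof=0.
var_biased = X.var(axis=0)

# Multiplying by n / (n - 1) yields the unbiased (ddof=1) estimator,
# which is what `var *= X.shape[axis] / (X.shape[axis] - 1)` does.
var_unbiased = var_biased * n / (n - 1)

assert np.allclose(var_unbiased, X.var(axis=0, ddof=1))
```

This is why skipping the correction in only one branch would silently change results between the dense and non-dense code paths.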
```python
@numba.njit(cache=True, parallel=True)
def _compute_mean_var(
```
We already have `_get_mean_var`. Maybe rename this to `_get_mean_var_ndarray` or `_get_mean_var_dense`?
I think we can rename the kernel.
```python
@numba.njit(cache=True, parallel=True)
def _compute_mean_var(
    X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads=1
```
Suggested change:

```diff
-    X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads=1
+    X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads: int = 1
```
```python
    if axis == 0:
        axis_i = 1
        sums = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
        sums_squared = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
        mean = np.zeros(X.shape[axis_i], dtype=np.float64)
        var = np.zeros(X.shape[axis_i], dtype=np.float64)
        n = X.shape[axis]
        for i in numba.prange(n_threads):
            for r in range(i, n, n_threads):
                for c in range(X.shape[axis_i]):
                    value = X[r, c]
                    sums[i, c] += value
                    sums_squared[i, c] += value * value
        for c in numba.prange(X.shape[axis_i]):
            sum_ = sums[:, c].sum()
            mean[c] = sum_ / n
            var[c] = (sums_squared[:, c].sum() - sum_ * sum_ / n) / (n - 1)
    else:
        axis_i = 0
        mean = np.zeros(X.shape[axis_i], dtype=np.float64)
        var = np.zeros(X.shape[axis_i], dtype=np.float64)
        for r in numba.prange(X.shape[0]):
            for c in range(X.shape[1]):
                value = X[r, c]
                mean[r] += value
                var[r] += value * value
        for c in numba.prange(X.shape[0]):
            mean[c] = mean[c] / X.shape[1]
            var[c] = (var[c] - X.shape[1] * mean[c] ** 2) / (X.shape[1] - 1)
```
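As a sanity check (not part of the PR), the single-pass sums/sums-of-squares formula the kernel relies on can be verified against numpy’s `ddof=1` variance with a plain-numpy reference:

```python
import numpy as np


def mean_var_reference(X, axis=0):
    # Same single-pass formula as the numba kernel: accumulate sums
    # and sums of squares, then form the mean and sample variance.
    n = X.shape[axis]
    s = X.sum(axis=axis, dtype=np.float64)
    sq = (X.astype(np.float64) ** 2).sum(axis=axis)
    mean = s / n
    var = (sq - s * s / n) / (n - 1)
    return mean, var


rng = np.random.default_rng(42)
X = rng.normal(size=(200, 8))
mean, var = mean_var_reference(X, axis=0)
assert np.allclose(mean, X.mean(axis=0))
assert np.allclose(var, X.var(axis=0, ddof=1))
```

Note that the single-pass formula is less numerically stable than a two-pass computation for large matrices with large means, which is presumably what the earlier comment about "more precision for larger matrices" refers to.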
Please don’t duplicate identical lines.
Suggested change (deduplicating the identical lines):

```python
    axis_i = 1 - axis
    mean = np.zeros(X.shape[axis_i], dtype=np.float64)
    var = np.zeros(X.shape[axis_i], dtype=np.float64)
    if axis == 0:
        sums = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
        sums_squared = np.zeros((n_threads, X.shape[axis_i]), dtype=np.float64)
        n = X.shape[axis]
        for i in numba.prange(n_threads):
            for r in range(i, n, n_threads):
                for c in range(X.shape[axis_i]):
                    value = X[r, c]
                    sums[i, c] += value
                    sums_squared[i, c] += value * value
        for c in numba.prange(X.shape[axis_i]):
            sum_ = sums[:, c].sum()
            mean[c] = sum_ / n
            var[c] = (sums_squared[:, c].sum() - sum_ * sum_ / n) / (n - 1)
    else:
        for r in numba.prange(X.shape[0]):
            for c in range(X.shape[1]):
                value = X[r, c]
                mean[r] += value
                var[r] += value * value
        for c in numba.prange(X.shape[0]):
            mean[c] = mean[c] / X.shape[1]
            var[c] = (var[c] - X.shape[1] * mean[c] ** 2) / (X.shape[1] - 1)
```
I think we can slim this down a bit. The two different loops need to be separate, though. The function should also work with a single thread.

`numba.get_num_threads()` is fine; it works well with the sparse arrays. But I have no experience with it inside of dask.
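On the dask question: a chunked computation cannot simply average per-chunk variances; it has to combine per-chunk sums and sums of squares. A pure-numpy sketch of that combination (illustrative only, not scanpy’s implementation):

```python
import numpy as np


def combine_chunk_stats(chunks, axis=0):
    # Accumulate per-chunk sums and sums of squares, then combine
    # them into a global mean/variance, the way a chunked (e.g.
    # dask) computation would have to.
    n = 0
    s = 0.0
    sq = 0.0
    for chunk in chunks:
        n += chunk.shape[axis]
        s = s + chunk.sum(axis=axis, dtype=np.float64)
        sq = sq + (chunk.astype(np.float64) ** 2).sum(axis=axis)
    mean = s / n
    var = (sq - s * s / n) / (n - 1)
    return mean, var


rng = np.random.default_rng(1)
X = rng.normal(size=(120, 6))
mean, var = combine_chunk_stats(np.array_split(X, 4, axis=0))
assert np.allclose(mean, X.mean(axis=0))
assert np.allclose(var, X.var(axis=0, ddof=1))
```

Because sums and sums of squares combine exactly across chunks, the per-chunk thread count only affects parallelism, not the result.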
```python
@numba.njit(cache=True, parallel=True)
def _compute_mean_var(
    X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads=1
```
I don't think `_SupportedArray` is the right type annotation here. This doesn't run directly on `dask.Array`, unless I am misunderstanding something.
Hi,

We are submitting this PR to speed up the `_get_mean_var` function.

Experiment setup: AWS r7i.24xlarge

We added a timer around the `_get_mean_var` call:

scanpy/scanpy/preprocessing/_scale.py
Line 167 in 706d4ef

We could also create a `_get_mean_var_std` function that returns the standard deviation as well, so we don't need to compute it in the scale function (L168-L169).
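A `_get_mean_var_std` helper as proposed could look like the following sketch (hypothetical: the name, signature, and dense-array-only scope are assumptions, not scanpy's actual API):

```python
import numpy as np


def _get_mean_var_std(X, axis=0):
    # Hypothetical helper sketched from the PR description: return
    # the standard deviation alongside mean and variance so scale()
    # does not have to recompute it from the variance afterwards.
    mean = X.mean(axis=axis, dtype=np.float64)
    var = X.var(axis=axis, dtype=np.float64, ddof=1)
    std = np.sqrt(var)
    return mean, var, std


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 4))
mean, var, std = _get_mean_var_std(X)
assert np.allclose(std ** 2, var)
```

In the actual PR this would presumably wrap the numba kernel rather than numpy, but the interface change (returning `std` as a third value) is the point being proposed.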