Optimize ndrolling nanreduce #4325
Comments
This is already done for Lines 370 to 382 in 1597e3a
This should work for most of the reductions (and is a bit similar to what is done in
I think this should not be too difficult. The thing is that rolling itself is already quite complicated.
@mathause
Agreed. One possible option would be to drop support for bottleneck.
I just saw that numpy 1.20 introduces |
I think I may have found a way to make the variance/standard deviation calculation more memory efficient, but I don't know enough about writing the sort of code that would be needed for a PR. I basically wrote out the calculation for variance, trying to use only the functions that have already been optimised. Derived from:
I coded this up and demonstrated that it uses approximately 10% of the memory as the current
%load_ext memory_profiler
import numpy as np
import xarray as xr

temp = xr.DataArray(np.random.randint(0, 10, (5000, 500)), dims=("x", "y"))

def new_var(da, x=10, y=20):
    # Defining the re-used parts
    roll = da.rolling(x=x, y=y)
    mean = roll.mean()
    count = roll.count()
    # First term: sum of squared values
    term1 = (da**2).rolling(x=x, y=y).sum()
    # Second term: cross term sum
    term2 = -2 * mean * roll.sum()
    # Third term: 'sum' of squared means
    term3 = count * mean**2
    # Combining into the variance
    var = (term1 + term2 + term3) / count
    return var

def old_var(da, x=10, y=20):
    roll = da.rolling(x=x, y=y)
    var = roll.var()
    return var

%memit new_var(temp)
%memit old_var(temp)
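I wanted to double check that the calculation was working correctly:
# Assumed definitions; var_o and var_n are not defined in the snippet above
var_o = old_var(temp)
var_n = new_var(temp)
print((var_o.where(~np.isnan(var_o), 0) == var_n.where(~np.isnan(var_n), 0)).all().values)
print(np.allclose(var_o, var_n, equal_nan=True))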
I think the difference here is just due to floating point errors, but maybe someone who knows how to check that in more detail could have a look. The standard deviation can be trivially implemented from this if the approach works.
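For anyone who wants to check the three-term expansion without xarray, here is a minimal NumPy-only sketch. It assumes NumPy >= 1.20 for `sliding_window_view`, only covers full ("valid") windows with no edge padding, and uses `ddof=0`. The helper names `window_sum` and `rolling_var_expanded` are made up for this illustration and are not part of xarray.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def window_sum(a, window):
    """Sum over every full 2-D window, read through a strided view."""
    return sliding_window_view(a, window).sum(axis=(-2, -1))


def rolling_var_expanded(a, window):
    """Population variance (ddof=0) per window, built from rolling sums only."""
    valid = ~np.isnan(a)
    filled = np.where(valid, a, 0.0)  # NaNs contribute nothing to the sums

    count = window_sum(valid.astype(np.float64), window)
    total = window_sum(filled, window)
    with np.errstate(invalid="ignore", divide="ignore"):
        mean = total / count                    # NaN where a window has no valid data
        term1 = window_sum(filled**2, window)   # sum of squared values
        term2 = -2.0 * mean * total             # cross term
        term3 = count * mean**2                 # 'sum' of squared means
        return (term1 + term2 + term3) / count


rng = np.random.default_rng(0)
data = rng.normal(size=(60, 50))
data[rng.random(data.shape) < 0.1] = np.nan

expanded = rolling_var_expanded(data, (3, 2))
# Reference computed the straightforward way, on the windowed array itself.
reference = np.nanvar(sliding_window_view(data, (3, 2)), axis=(-2, -1))
print(np.allclose(expanded, reference, equal_nan=True))
```

One caveat on the floating point question: the expanded form subtracts quantities of similar size, so it can lose precision when the mean is large compared with the spread of the data. That would explain why an exact `==` comparison can fail while `np.allclose` passes.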
Over in #7344 (comment) @shoyer
After some digging, this would involve using "summed area tables" which have been generalized to nD, and can be used to compute all our built-in reductions (except median). Basically we'd store the summed area table (repeated This would be an intermediate level project but we could implement it incrementally (start with cc @aulemahal |
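To make the summed area table idea concrete, here is a minimal 2-D sketch in plain NumPy. It is an illustration under stated assumptions, not xarray's implementation: the function names `summed_area_table` and `rolling_sum_from_sat` are hypothetical, only full windows are computed, and NaNs are handled by zero-filling before building the table.

```python
import numpy as np


def summed_area_table(a):
    """2-D prefix sums with a leading row and column of zeros.

    sat[i, j] holds the sum of a[:i, :j].
    """
    sat = np.zeros((a.shape[0] + 1, a.shape[1] + 1), dtype=np.float64)
    sat[1:, 1:] = a.cumsum(axis=0).cumsum(axis=1)
    return sat


def rolling_sum_from_sat(sat, window):
    """Sum of every full window from four table lookups (inclusion-exclusion)."""
    wx, wy = window
    return sat[wx:, wy:] - sat[:-wx, wy:] - sat[wx:, :-wy] + sat[:-wx, :-wy]


a = np.arange(20, dtype=np.float64).reshape(4, 5)
a[1, 1] = np.nan

valid = ~np.isnan(a)
sat_values = summed_area_table(np.where(valid, a, 0.0))
sat_counts = summed_area_table(valid.astype(np.float64))

window = (2, 3)
roll_sum = rolling_sum_from_sat(sat_values, window)    # rolling nansum
roll_count = rolling_sum_from_sat(sat_counts, window)  # rolling count
roll_mean = roll_sum / roll_count                      # rolling nanmean
```

Each table is the size of the input plus one row and one column, and the cost per output element does not depend on the window size. The trade-off is that the prefix sums can grow large, so differences of table entries may lose precision on long arrays.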
In #4219 we added ndrolling.
However, a nanreduce such as
ds.rolling(x=3, y=2).mean()
calls np.nanmean, which copies the strided array into a full array. This is memory-inefficient.
We can implement in-house nanreduce methods for the strided array.
For example, our .nansum currently does
make a strided array -> copy the array -> replace nan by 0 -> sum
but we can instead do
replace nan by 0 -> make a strided array -> sum
This is much more memory efficient.
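The two orderings can be sketched with plain NumPy as follows. This is only an illustration of the idea, not xarray's rolling code: it assumes NumPy >= 1.20 for `sliding_window_view`, skips the edge padding that xarray's rolling does, and the function names are made up for this example.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def rolling_nansum_copy(a, window):
    """Window first, then fill: np.where materialises the full windowed array."""
    windows = sliding_window_view(a, window)             # strided view, no copy yet
    filled = np.where(np.isnan(windows), 0.0, windows)   # copy of shape (..., *window)
    return filled.sum(axis=(-2, -1))


def rolling_nansum_view(a, window):
    """Fill first, then window: the only new array is the size of the input."""
    filled = np.where(np.isnan(a), 0.0, a)               # copy of the input's size
    windows = sliding_window_view(filled, window)        # strided view of that copy
    return windows.sum(axis=(-2, -1))


rng = np.random.default_rng(0)
data = rng.normal(size=(50, 40))
data[rng.random(data.shape) < 0.1] = np.nan

slow = rolling_nansum_copy(data, (3, 2))
fast = rolling_nansum_view(data, (3, 2))
print(np.allclose(slow, fast))
```

In the first version the intermediate array is roughly `3 * 2` times the size of the input; in the second it is the same size as the input, and the window factor never shows up in the allocation.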