Improvements to lazy behaviour of xr.cov() and xr.corr() #5390
Conversation
xarray/core/computation.py
Outdated
demeaned_da_a = da_a - da_a.mean(dim=dim)
demeaned_da_b = da_b - da_b.mean(dim=dim)
# https://github.com/pydata/xarray/issues/4804#issuecomment-760114285
demeaned_da_ab = (da_a * da_b) - (da_a.mean(dim=dim) * da_b.mean(dim=dim))
shouldn't this be
-demeaned_da_ab = (da_a * da_b) - (da_a.mean(dim=dim) * da_b.mean(dim=dim))
+demeaned_da_ab = (da_a * da_b).mean(dim=dim) - da_a.mean(dim=dim) * da_b.mean(dim=dim)
aaah yes sorry, one beer too many..
@keewis actually no, it was right the first time I think! Because later on there's a demeaned_da_ab.sum(dim=dim) ...
I think there was actually a typo/error in the original suggestion in #4804
no, I think that was right: we're currently doing something like
arr = (x - x.mean()) * (y - y.mean())
arr.sum(...) / valid_count
if we then use the suggestion from #4804 this becomes
arr = (x * y).mean() - x.mean() * y.mean()
arr.sum(...) / valid_count
Edit: actually, no, you're right, the first one does return an array while the second returns a scalar.
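(For anyone skimming later, a tiny sketch of that shape point, using made-up arrays rather than the PR's test data: the element-wise form keeps the dimension around for the later .sum(dim=dim), while taking .mean() up front collapses it.)

import numpy as np
import xarray as xr

x = xr.DataArray(np.arange(6.0).reshape(2, 3), dims=("space", "time"))
y = xr.DataArray(np.ones((2, 3)), dims=("space", "time"))

# element-wise form: still has the "time" dimension
elementwise = (x * y) - x.mean("time") * y.mean("time")
# aggregated form: "time" is already reduced away
aggregated = (x * y).mean("time") - x.mean("time") * y.mean("time")

print(elementwise.dims)  # ('space', 'time')
print(aggregated.dims)   # ('space',)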
apparently, this would need to become
x2 = a * b - a * b.mean() - b * a.mean() + a.mean() * b.mean()
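For what it's worth, a quick numpy check (my own made-up data, not from the PR) that this expansion reproduces the demeaned product element-wise:

import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=100)
b = rng.normal(size=100)

demeaned = (a - a.mean()) * (b - b.mean())
expanded = a * b - a * b.mean() - b * a.mean() + a.mean() * b.mean()

# identical element-wise, so anything summed or averaged downstream is unchanged
np.testing.assert_allclose(demeaned, expanded)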
Aaah yes, good point. I wonder if this still causes the same problem as was noted in the original comment? i.e. it "leads to Dask holding all chunks of x in memory". I'm not deep enough into Dask to understand the discussion in dask/dask#6674
I guess that still has that issue.
Right, so actually that comment states that (note the additional .mean() in the left-hand side):
((a - a.mean()) * (b - b.mean())).mean() = (a * b).mean() - a.mean() * b.mean()
Actually @keewis, what really matters is if np.testing.assert_allclose(x1.sum(), x2.sum()) raises or not. We're not asserting that they're the same, but just that they give the same result when we sum over the appropriate dimension (which they do). This is why the tests are passing (so far).
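A minimal numpy illustration of that point (my own sketch, with NaN-free random data): the two arrays differ element-wise, but their sums over the reduced axis agree.

import numpy as np

rng = np.random.default_rng(42)
x = rng.normal(size=(3, 50))  # e.g. (space, time)
y = rng.normal(size=(3, 50))

x1 = (x - x.mean(axis=-1, keepdims=True)) * (y - y.mean(axis=-1, keepdims=True))
x2 = x * y - x.mean(axis=-1, keepdims=True) * y.mean(axis=-1, keepdims=True)

assert not np.allclose(x1, x2)                                # not the same arrays ...
np.testing.assert_allclose(x1.sum(axis=-1), x2.sum(axis=-1))  # ... but the same sums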
@AndrewWilliams3142 See here for a quick illustration of the issue: https://nbviewer.jupyter.org/gist/willirath/a555521c5b979509435fc247d69ea895
Co-authored-by: keewis <keewis@users.noreply.github.com>
Hello @AndrewWilliams3142! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found: There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2021-05-29 14:35:16 UTC
@willirath, thanks for your example notebook! I'm still trying to get my head around this a bit though. Say you have

da_a = xr.DataArray(
    np.array([[1, 2, 3, 4], [1, 0.1, 0.2, 0.3], [2, 3.2, 0.6, 1.8]]),
    dims=("space", "time"),
    coords=[
        ("space", ["IA", "IL", "IN"]),
        ("time", pd.date_range("2000-01-01", freq="1D", periods=4)),
    ],
).chunk()

da_b = xr.DataArray(
    np.array([[0.2, 0.4, 0.6, 2], [15, 10, 5, 1], [1, 3.2, np.nan, 1.8]]),
    dims=("space", "time"),
    coords=[
        ("space", ["IA", "IL", "IN"]),
        ("time", pd.date_range("2000-01-01", freq="1D", periods=4)),
    ],
).chunk()

The original computation had this graph: [graph image not preserved]

Whereas my alteration now has a graph more like this: [graph image not preserved]

Am I correct in thinking that this is a 'better' computational graph? Because the original chunks are not passed on to later points in the computation?
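(A rough way to inspect this yourself, as a sketch on top of the da_a/da_b defined above: count the tasks in each graph and/or render them. Requires dask, plus graphviz for the rendering step; note the two expressions are not numerically identical once NaNs are involved, which comes up further down this thread.)

import dask

original = ((da_a - da_a.mean("time")) * (da_b - da_b.mean("time"))).mean("time")
altered = ((da_a * da_b) - da_a.mean("time") * da_b.mean("time")).mean("time")

# task counts are only a rough proxy; which tasks dask can release early is
# easier to see in the rendered graphs
print(len(original.__dask_graph__()), len(altered.__dask_graph__()))

dask.visualize(original, filename="original_graph.svg")
dask.visualize(altered, filename="altered_graph.svg")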
@AndrewWilliams3142 I think that's right. You can confirm these ideas by profiling a test problem: https://docs.dask.org/en/latest/diagnostics-local.html#example

It does seem like with the new version dask will hold on to [...]
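For reference, a sketch of that profiling approach on a made-up problem (array sizes and chunking are arbitrary; the plot needs bokeh installed):

import dask.array as darr
from dask.diagnostics import Profiler, ResourceProfiler, visualize

x = darr.random.normal(size=(400, 100, 100), chunks=(50, 100, 100))
y = darr.random.normal(size=(400, 100, 100), chunks=(50, 100, 100))

# the "aggregating" formulation discussed in this thread
cov_agg = (x * y).mean(axis=0) - x.mean(axis=0) * y.mean(axis=0)

with Profiler() as prof, ResourceProfiler(dt=0.25) as rprof:
    cov_agg.compute()

visualize([prof, rprof])  # interactive bokeh plot: task stream + memory/CPU usage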
@AndrewWilliams3142 @dcherian Looks like I broke the first Gist. :(

Your example above does not quite get there, because the [...]

Here's a Gist that explains the idea for the correlations: https://nbviewer.jupyter.org/gist/willirath/c5c5274f31c98e8452548e8571158803

With

X = xr.DataArray(
    darr.random.normal(size=array_size, chunks=chunk_size),
    dims=("t", "y", "x"),
    name="X",
)
Y = xr.DataArray(
    darr.random.normal(size=array_size, chunks=chunk_size),
    dims=("t", "y", "x"),
    name="Y",
)

the "bad" / explicit way of calculating the correlation

corr_exp = ((X - X.mean("t")) * (Y - Y.mean("t"))).mean("t")

Dask won't release any of the tasks defining [...] [graph image not preserved]

The "good" / aggregating way of calculating the correlation

corr_agg = (X * Y).mean("t") - X.mean("t") * Y.mean("t")

has the following graph: [graph image not preserved]
@willirath this is great stuff, thanks again! So generally it looks like the graph is more efficient when doing operations of the form:

(X * Y).mean('time') - (X.mean('time') * Y.mean('time'))

than doing

((X - X.mean('time')) * (Y - Y.mean('time'))).mean('time')

or like what I've implemented (see screenshot)?

intermediate = (X * Y) - (X.mean('time') * Y.mean('time'))
intermediate.mean('time')

If so, it seems like the most efficient(?) way to do the computation in _cov_corr() is to combine it all into one line? I can't think of how to do this though...
# 3. Detrend along the given dim
# 4. Compute covariance along the given dim
# N.B. `skipna=False` is required or there is a bug when computing
# auto-covariance. E.g. Try xr.cov(da,da) for
# da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
def _mean(da):
    return da.sum(dim=dim, skipna=True, min_count=1) / (valid_count)

cov = _mean(da_a * da_b) - _mean(da_a.mean(dim=dim) * da_b.mean(dim=dim))

This second term looks very weird to me, it should be a no-op:

_mean(da_a.mean(dim=dim) * da_b.mean(dim=dim))

is it just [...]?
I think you'd still have to normalize the second term by [...]
Shouldn't the following do?

cov = (
    (da_a * da_b).mean(dim)
    - (
        da_a.where(da_b.notnull()).mean(dim)
        * da_b.where(da_a.notnull()).mean(dim)
    )
)

(See here: https://nbviewer.jupyter.org/gist/willirath/cfaa8fb1b53fcb8dcb05ddde839c794c )
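A quick numerical sanity check of this suggestion (my own toy example with a single NaN, comparing against a direct pairwise-complete covariance normalised by N, i.e. ddof=0):

import numpy as np
import xarray as xr

da_a = xr.DataArray([1.0, 2.0, 3.0, 4.0], dims="time")
da_b = xr.DataArray([0.5, np.nan, 1.5, 2.5], dims="time")
dim = "time"

cov = (da_a * da_b).mean(dim) - (
    da_a.where(da_b.notnull()).mean(dim) * da_b.where(da_a.notnull()).mean(dim)
)

# reference: restrict to entries where both arrays are valid
valid = (da_a.notnull() & da_b.notnull()).values
a, b = da_a.values[valid], da_b.values[valid]
expected = (a * b).mean() - a.mean() * b.mean()

np.testing.assert_allclose(cov.values, expected)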
I think the problem with [...] is that the [...]
@willirath this is cool, but I think it doesn't explain why the tests fail. Currently [...]

@dcherian, I think I've got it to work, but you need to account for the length(s) of the dimension you're calculating the correlation over (i.e. [...]). This latest commit does this, but I'm not sure whether the added complication is worth it yet? Thoughts welcome.

def _mean(da):
    return (da.sum(dim=dim, skipna=True, min_count=1) / (valid_count))

dim_length = da_a.notnull().sum(dim=dim, skipna=True)

def _mean_detrended_term(da):
    return (dim_length * da / (valid_count))

cov = _mean(da_a * da_b) - _mean_detrended_term(da_a.mean(dim=dim) * da_b.mean(dim=dim))
It's a shame we let this drop. Is anyone familiar with the current state? Could we resurrect this?
Following @willirath's suggestion in #4804, I've changed https://github.com/pydata/xarray/blob/master/xarray/core/computation.py#L1373_L1375 so that Dask doesn't hold all chunks in memory.
pre-commit run --all-files