
Improvements to lazy behaviour of xr.cov() and xr.corr() #5390


Open · wants to merge 5 commits into main

Conversation

AndrewILWilliams (Contributor)

Following @willirath's suggestion in #4804, I've changed https://github.com/pydata/xarray/blob/master/xarray/core/computation.py#L1373_L1375 so that Dask doesn't hold all chunks in memory:

demeaned_da_a = da_a - da_a.mean(dim=dim)
demeaned_da_b = da_b - da_b.mean(dim=dim)
# https://github.com/pydata/xarray/issues/4804#issuecomment-760114285
demeaned_da_ab = (da_a * da_b) - (da_a.mean(dim=dim) * da_b.mean(dim=dim))
Collaborator

shouldn't this be

Suggested change
demeaned_da_ab = (da_a * da_b) - (da_a.mean(dim=dim) * da_b.mean(dim=dim))
demeaned_da_ab = (da_a * da_b).mean(dim=dim) - da_a.mean(dim=dim) * da_b.mean(dim=dim)

Contributor Author

aaah yes sorry, one beer too many..

Contributor Author

@keewis actually no, it was right the first time I think! Because later on there's a demeaned_da_ab.sum(dim=dim) ...

Contributor Author

I think there was actually a typo/error in the original suggestion in #4804

@keewis (Collaborator) May 27, 2021

no, I think that was right: we're currently doing something like

arr = (x - x.mean()) * (y - y.mean())
arr.sum(...) / valid_count

if we then use the suggestion from #4804 this becomes

arr = (x * y).mean() - x.mean() * y.mean()
arr.sum(...) / valid_count

Edit: actually, no, you're right, the first one does return a array while the second returns a scalar

@keewis (Collaborator) May 27, 2021

apparently, this would need to become

x2 = a * b - a * b.mean() - b * a.mean() + a.mean() * b.mean()

@AndrewILWilliams (Contributor Author) May 27, 2021

Aaah yes, good point. I wonder if this still causes the same problem as was noted in the original comment? i.e. it "leads to Dask holding all chunks of x in memory". I'm not deep enough into Dask to understand the discussion in dask/dask#6674

Collaborator

I guess that still has that issue.

Right, so actually that comment states that (note the additional .mean() in the left-hand side):

((a - a.mean()) * (b - b.mean())).mean() = (a * b).mean() - a.mean() * b.mean()
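This is the standard expansion of the covariance. A quick NumPy sanity check (hypothetical random arrays, no missing data) confirms the identity numerically:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=100)
b = rng.normal(size=100)

# Left-hand side: mean of the demeaned product
lhs = ((a - a.mean()) * (b - b.mean())).mean()
# Right-hand side: aggregated form
rhs = (a * b).mean() - a.mean() * b.mean()

np.testing.assert_allclose(lhs, rhs)  # agree up to floating-point error
```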

Contributor Author

Actually @keewis, what really matters is whether np.testing.assert_allclose(x1.sum(), x2.sum()) raises or not. We're not asserting that they're the same, just that they give the same result when we sum over the appropriate dimension (which they do). This is why the tests are passing (so far).
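This claim checks out with plain NumPy: the two intermediates differ elementwise, but their sums over the reduced dimension agree. A sketch with hypothetical arrays (x1 mirrors demeaned_da_a * demeaned_da_b, x2 mirrors demeaned_da_ab; no missing data assumed):

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.normal(size=(3, 50))  # dims ("space", "time"); reduce over axis=1
b = rng.normal(size=(3, 50))

a_mean = a.mean(axis=1, keepdims=True)
b_mean = b.mean(axis=1, keepdims=True)

x1 = (a - a_mean) * (b - b_mean)  # fully demeaned product
x2 = a * b - a_mean * b_mean      # the PR's cheaper intermediate

# x1 != x2 elementwise, but their sums over "time" agree
np.testing.assert_allclose(x1.sum(axis=1), x2.sum(axis=1))
```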

Contributor

@AndrewWilliams3142 See here for a quick illustration of the issue: https://nbviewer.jupyter.org/gist/willirath/a555521c5b979509435fc247d69ea895

AndrewILWilliams and others added 2 commits May 27, 2021 21:28
Co-authored-by: keewis <keewis@users.noreply.github.com>
@pep8speaks commented May 27, 2021

Hello @AndrewWilliams3142! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found:

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-05-29 14:35:16 UTC

@AndrewILWilliams (Contributor Author) commented May 28, 2021

@willirath , thanks for your example notebook! I'm still trying to get my head around this a bit though.

Say you have da_a and da_b defined as:

da_a = xr.DataArray(
    np.array([[1, 2, 3, 4], [1, 0.1, 0.2, 0.3], [2, 3.2, 0.6, 1.8]]),
    dims=("space", "time"),
    coords=[
        ("space", ["IA", "IL", "IN"]),
        ("time", pd.date_range("2000-01-01", freq="1D", periods=4)),
    ],
).chunk()

da_b = xr.DataArray(
    np.array([[0.2, 0.4, 0.6, 2], [15, 10, 5, 1], [1, 3.2, np.nan, 1.8]]),
    dims=("space", "time"),
    coords=[
        ("space", ["IA", "IL", "IN"]),
        ("time", pd.date_range("2000-01-01", freq="1D", periods=4)),
    ],
).chunk()

The original computation in _cov_corr has a graph something like:
[task graph screenshot]

Whereas my alteration now has a graph more like this:
[task graph screenshot]

Am I correct in thinking that this is a 'better' computational graph, because the original chunks are not passed on to later points in the computation?

@dcherian (Contributor)

@AndrewWilliams3142 I think that's right. You can confirm these ideas by profiling a test problem: https://docs.dask.org/en/latest/diagnostics-local.html#example

It does seem like with the new version dask will hold on to da_a * da_b for a while, which is an improvement over holding da_a and da_b separately for a while.

@willirath (Contributor)

@AndrewWilliams3142 @dcherian Looks like I broke the first Gist. :(

Your example above does not quite get there, because xr.DataArray(np...).chunk() just leads to one chunk per data array.

Here's a Gist that explains the idea for the correlations: https://nbviewer.jupyter.org/gist/willirath/c5c5274f31c98e8452548e8571158803

With

X = xr.DataArray(
    darr.random.normal(size=array_size, chunks=chunk_size),
    dims=("t", "y", "x"),
    name="X",
)

Y = xr.DataArray(
    darr.random.normal(size=array_size, chunks=chunk_size),
    dims=("t", "y", "x"),
    name="Y",
)

the "bad" / explicit way of calculating the correlation

corr_exp = ((X - X.mean("t")) * (Y - Y.mean("t"))).mean("t")

leads to a graph like this:
[task graph screenshot]

Dask won't release any of the tasks defining X and Y until the marked subtraction tasks are done.

The "good" / aggregating way of calculating the correlation

corr_agg = (X * Y).mean("t") - X.mean("t") * Y.mean("t")

has the following graph
[task graph screenshot]
where the marked multiplication and mean_chunk tasks act only on pairs of chunks or individual chunks and then release the original chunks of X and Y. This graph can be evaluated with a much smaller memory footprint than the other one. (It's not certain that this always leads to lower memory use, however, but that is a different issue...)

@AndrewILWilliams (Contributor Author) commented May 28, 2021

@willirath this is great stuff, thanks again! So generally it looks like the graph is more efficient when doing operations of the form:

(X * Y).mean('time') - (X.mean('time') * Y.mean('time'))

than doing

((X - X.mean('time')) * (Y-Y.mean('time'))).mean('time')

or like what I've implemented (see screenshot)?

intermediate = (X * Y) - (X.mean('time') * Y.mean('time'))

intermediate.mean('time')

[task graph screenshot]

If so, it seems like the most efficient(?) way to do the computation in _cov_corr() is to combine it all into one line? I can't think of how to do this though...
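For complete data, the implemented form and the fully aggregated one agree exactly: since X.mean('time') * Y.mean('time') is constant along 'time', averaging the intermediate over 'time' reproduces the aggregated expression. A NumPy sketch (hypothetical arrays, no missing data):

```python
import numpy as np

rng = np.random.default_rng(2)
X = rng.normal(size=(40, 8, 8))  # dims ("t", "y", "x"); reduce over axis=0
Y = rng.normal(size=(40, 8, 8))

mx = X.mean(axis=0)
my = Y.mean(axis=0)

corr_agg = (X * Y).mean(axis=0) - mx * my   # fully aggregated form
intermediate = X * Y - mx * my              # broadcasts along "t"

# Averaging the intermediate over "t" recovers the aggregated result
np.testing.assert_allclose(intermediate.mean(axis=0), corr_agg)
```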

@dcherian (Contributor) commented May 28, 2021

# 3. Detrend along the given dim
# 4. Compute covariance along the given dim
# N.B. `skipna=False` is required or there is a bug when computing
# auto-covariance. E.g. Try xr.cov(da, da) for
# da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
def _mean(da):
    return da.sum(dim=dim, skipna=True, min_count=1) / (valid_count)

cov = _mean(da_a * da_b) - _mean(da_a.mean(dim=dim) * da_b.mean(dim=dim))

This second term looks very weird to me; it should be a no-op:

_mean(da_a.mean(dim=dim) * da_b.mean(dim=dim))

is it just

cov = _mean(da_a * da_b) - da_a.mean(dim=dim) * da_b.mean(dim=dim)

@AndrewILWilliams (Contributor Author) commented May 28, 2021

is it just

cov = _mean(da_a * da_b) - da_a.mean(dim=dim) * da_b.mean(dim=dim)

I think you'd still have to normalize the second term by 1 / (valid_count). However, I just tried both of these approaches and neither passes the test suite, so we may need to do more thinking...

@willirath (Contributor)

Shouldn't the following do?

cov = (
    (da_a * da_b).mean(dim) 
    - (
        da_a.where(da_b.notnull()).mean(dim)
        * da_b.where(da_a.notnull()).mean(dim)
    )
)

(See here: https://nbviewer.jupyter.org/gist/willirath/cfaa8fb1b53fcb8dcb05ddde839c794c )
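This masked form is an exact identity for pairwise-complete data: restricting every mean to positions where both arrays are valid reproduces the demeaned covariance. A NumPy sketch with missing values (hypothetical data; nanmean stands in for xarray's skipna mean, np.where for .where(...notnull())):

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0, 4.0, np.nan])
b = np.array([0.5, np.nan, 1.5, 2.0, 3.0])

# Reference: demeaned covariance over pairwise-complete entries only
valid = ~np.isnan(a) & ~np.isnan(b)
av, bv = a[valid], b[valid]
cov_ref = ((av - av.mean()) * (bv - bv.mean())).mean()

# Masked aggregated form, as in the suggestion above:
# each mean only sees positions where the *other* array is also valid
cov = np.nanmean(a * b) - (
    np.nanmean(np.where(~np.isnan(b), a, np.nan))
    * np.nanmean(np.where(~np.isnan(a), b, np.nan))
)

np.testing.assert_allclose(cov, cov_ref)
```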

@willirath (Contributor) commented May 29, 2021

I think the problem with

cov = _mean(da_a * da_b) - da_a.mean(dim=dim) * da_b.mean(dim=dim)

is that the da_a.mean() and the da_b.mean() calls don't know about each other's missing data.
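A small NumPy counterexample (hypothetical data) shows this failure mode: with missing values, the two independent means are taken over different sample sets, so the aggregated form no longer matches the pairwise-complete covariance:

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0, np.nan])
b = np.array([10.0, np.nan, 30.0, 40.0])

# Reference: covariance over entries where *both* are valid (indices 0 and 2)
valid = ~np.isnan(a) & ~np.isnan(b)
av, bv = a[valid], b[valid]
cov_ref = ((av - av.mean()) * (bv - bv.mean())).mean()

# Naive aggregated form: each mean only skips its *own* NaNs,
# so a.mean() still includes entries where b is missing, and vice versa
cov_naive = np.nanmean(a * b) - np.nanmean(a) * np.nanmean(b)

print(cov_ref, cov_naive)  # the two disagree
```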

@AndrewILWilliams (Contributor Author) commented May 29, 2021

@willirath this is cool, but I think it doesn't explain why the tests fail. Currently da_a.mean() and the da_b.mean() calls do know about each other's missing data! That's what we're doing in these lines.

@dcherian, I think I've got it to work, but you need to account for the length of the dimension you're calculating the correlation over: (da - da.mean('time')).sum('time') is not the same as da.sum('time') - da.mean('time'); you actually need da.sum('time') - da.mean('time') * length_of_time_dim.

This latest commit does this, but I'm not sure whether the added complication is worth it yet? Thoughts welcome.

def _mean(da):
    return da.sum(dim=dim, skipna=True, min_count=1) / (valid_count)

dim_length = da_a.notnull().sum(dim=dim, skipna=True)

def _mean_detrended_term(da):
    return dim_length * da / (valid_count)

cov = _mean(da_a * da_b) - _mean_detrended_term(da_a.mean(dim=dim) * da_b.mean(dim=dim))

@max-sixty (Collaborator)

It's a shame we let this drop.

Is anyone familiar with the current state? Could we resurrect this?


6 participants