Add DataTree.persist #9682

Merged 6 commits on Oct 28, 2024. Showing changes from 2 commits.
5 changes: 5 additions & 0 deletions doc/api.rst
@@ -656,6 +656,7 @@ This interface echoes that of ``xarray.Dataset``.
DataTree.has_attrs
DataTree.is_empty
DataTree.is_hollow
DataTree.chunksizes

Dictionary Interface
--------------------
@@ -968,6 +969,10 @@ DataTree methods
DataTree.to_dict
DataTree.to_netcdf
DataTree.to_zarr
DataTree.chunk
DataTree.load
DataTree.compute
DataTree.persist

.. ..

3 changes: 2 additions & 1 deletion doc/whats-new.rst
@@ -21,7 +21,8 @@ v.2024.10.1 (unreleased)

New Features
~~~~~~~~~~~~

- Added :py:meth:`DataTree.persist` method (:issue:`9675`, :pull:`9682`).
By `Sam Levang <https://github.com/slevang>`_.

Breaking changes
~~~~~~~~~~~~~~~~
58 changes: 58 additions & 0 deletions xarray/core/datatree.py
@@ -45,6 +45,7 @@
    _default,
    drop_dims_from_indexers,
    either_dict_or_kwargs,
    is_duck_dask_array,
    maybe_wrap_array,
    parse_dims_as_set,
)
@@ -1984,6 +1985,63 @@ def compute(self, **kwargs) -> Self:
        new = self.copy(deep=False)
        return new.load(**kwargs)

    def _persist_inplace(self, **kwargs) -> Self:
        """Persist all Dask arrays in memory."""
        # access .data to coerce everything to numpy or dask arrays
        lazy_data = {
            path: {
                k: v._data
                for k, v in node.variables.items()
                if is_duck_dask_array(v._data)
            }
            for path, node in self.subtree_with_keys
        }
        flat_lazy_data = {
            (path, var_name): array
            for path, node in lazy_data.items()
            for var_name, array in node.items()
        }
        if flat_lazy_data:
            import dask

            # evaluate all the dask arrays simultaneously
            evaluated_data = dask.persist(*flat_lazy_data.values(), **kwargs)
Review comment (Member): Actually I lied - there is at least a ToDo to be noted here that this should be routed through the chunkmanager. The chunkmanager doesn't have persist at the moment, but it should, even if dask is the only one to implement it.

Reply (Contributor Author): Sure, I just copied Dataset persist, which looks like it hasn't been touched since long before the chunkmanager was introduced. Should we add persist here as a non-abstract method on the chunkmanager classes? Looks pretty straightforward.

Reply (Member): Yes, that would be great.

Reply (Member): Technically it could be a separate PR (a prerequisite to this one), but it's not super important to split them.

            for (path, var_name), data in zip(
                flat_lazy_data, evaluated_data, strict=False
            ):
                self[path].variables[var_name].data = data

        return self
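The two dict comprehensions above can be illustrated with a dask-free toy example: a nested ``{path: {var_name: array}}`` mapping is flattened into ``{(path, var_name): array}`` so every lazy array can be handed to a single ``dask.persist`` call and then written back by key. Plain lists stand in for dask arrays here; this is a sketch, not the actual xarray code.

```python
# Toy stand-in for the flattening in _persist_inplace; plain lists
# play the role of dask arrays.
lazy_data = {
    "/": {"a": [1, 2]},
    "/group1": {"b": [3]},
}

# Flatten {path: {var_name: array}} into {(path, var_name): array}.
flat_lazy_data = {
    (path, var_name): array
    for path, node in lazy_data.items()
    for var_name, array in node.items()
}

print(sorted(flat_lazy_data))  # [('/', 'a'), ('/group1', 'b')]
```

Keying the flat dict by ``(path, var_name)`` tuples is what lets the loop after ``dask.persist`` write each evaluated array back to the correct node and variable.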

    def persist(self, **kwargs) -> Self:
        """Trigger computation, keeping data as dask arrays.

        This operation can be used to trigger computation on underlying dask
        arrays, similar to ``.compute()`` or ``.load()``. However, this
        operation keeps the data as dask arrays, which is particularly useful
        when using the dask.distributed scheduler and you want to load a large
        amount of data into distributed memory. Like ``.compute()`` (but
        unlike ``.load()``), the original object is left unaltered.

        Parameters
        ----------
        **kwargs : dict
            Additional keyword arguments passed on to ``dask.persist``.

        Returns
        -------
        object : DataTree
            New object with all dask-backed coordinates and data variables as
            persisted dask arrays.

        See Also
        --------
        dask.persist
        """
        new = self.copy(deep=False)
        return new._persist_inplace(**kwargs)

    @property
    def chunksizes(self) -> Mapping[str, Mapping[Hashable, tuple[int, ...]]]:
        """
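The review thread above suggests routing persist through the chunkmanager. A hedged sketch of what the reviewers describe: a non-abstract ``persist`` on a base class modeled on xarray's ``ChunkManagerEntrypoint``, with a dask override. Class names and signatures here are illustrative assumptions, not the API this or any follow-up PR actually added.

```python
from typing import Any


class ChunkManagerSketch:
    """Illustrative stand-in for xarray's ChunkManagerEntrypoint."""

    def persist(self, *data: Any, **kwargs: Any) -> tuple[Any, ...]:
        # Non-abstract default: chunk managers without persist support
        # raise instead of silently doing nothing.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement persist"
        )


class DaskManagerSketch(ChunkManagerSketch):
    def persist(self, *data: Any, **kwargs: Any) -> tuple[Any, ...]:
        import dask  # deferred import so non-dask users never need it

        return dask.persist(*data, **kwargs)
```

Making the base method non-abstract (as the author proposes) means existing third-party chunk managers keep working unchanged, while dask gets a real implementation.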
51 changes: 51 additions & 0 deletions xarray/tests/test_datatree.py
@@ -2280,6 +2280,57 @@ def test_compute(self):
        assert actual.chunksizes == expected_chunksizes, "mismatching chunksizes"
        assert tree.chunksizes == original_chunksizes, "original tree was modified"

    def test_persist(self):
        ds1 = xr.Dataset({"a": ("x", np.arange(10))})
        ds2 = xr.Dataset({"b": ("y", np.arange(5))})
        ds3 = xr.Dataset({"c": ("z", np.arange(4))})
        ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))})

        def fn(x):
            return 2 * x

        expected = xr.DataTree.from_dict(
            {
                "/": fn(ds1).chunk({"x": 5}),
                "/group1": fn(ds2).chunk({"y": 3}),
                "/group2": fn(ds3).chunk({"z": 2}),
                "/group1/subgroup1": fn(ds4).chunk({"x": 5}),
            }
        )
        # Add a trivial second layer to the task graph; persist should reduce it to one
        tree = xr.DataTree.from_dict(
            {
                "/": fn(ds1.chunk({"x": 5})),
                "/group1": fn(ds2.chunk({"y": 3})),
                "/group2": fn(ds3.chunk({"z": 2})),
                "/group1/subgroup1": fn(ds4.chunk({"x": 5})),
            }
        )
        original_chunksizes = tree.chunksizes
        original_hlg_depths = {
            node.path: len(node.dataset.__dask_graph__().layers)
            for node in tree.subtree
        }

        actual = tree.persist()
        actual_hlg_depths = {
            node.path: len(node.dataset.__dask_graph__().layers)
            for node in actual.subtree
        }

        assert_identical(actual, expected)

        assert actual.chunksizes == original_chunksizes, "chunksizes were modified"
        assert (
            tree.chunksizes == original_chunksizes
        ), "original chunksizes were modified"
        assert all(
            d == 1 for d in actual_hlg_depths.values()
        ), "unexpected dask graph depth"
        assert all(
            d == 2 for d in original_hlg_depths.values()
        ), "original dask graph was modified"

    def test_chunk(self):
        ds1 = xr.Dataset({"a": ("x", np.arange(10))})
        ds2 = xr.Dataset({"b": ("y", np.arange(5))})