implement dask methods on DataTree (#9670)
* implement `compute` and `load`

* also shallow-copy variables

* implement `chunksizes`

* add tests for `load`

* add tests for `chunksizes`

* improve the `load` tests using `DataTree.chunksizes`

* add a test for `compute`

* un-xfail an xpassing test

* implement and test `DataTree.chunk`

* link to `Dataset.load`

Co-authored-by: Tom Nicholas <tom@cworthy.org>

* use `tree.subtree` to get absolute paths

* filter out missing dims before delegating to `Dataset.chunk`

* fix the type hints for `DataTree.chunksizes`

* try using `self.from_dict` instead

* type-hint intermediate test variables

* use `_node_dims` instead

* raise on unknown chunk dim

* check that errors in `chunk` are raised properly

* adapt the docstrings of the new methods

* allow computing / loading unchunked trees

* reword the `chunksizes` properties

* also freeze the top-level chunk sizes

* also reword `DataArray.chunksizes`

* fix a copy-paste error

* same for `NamedArray.chunksizes`

---------

Co-authored-by: Tom Nicholas <tom@cworthy.org>
keewis and TomNicholas authored Oct 24, 2024
1 parent 521b087 commit 1f5889a
Showing 5 changed files with 342 additions and 14 deletions.
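
For orientation, a minimal usage sketch of the API this commit adds (the tree contents here are illustrative, and dask is assumed to be installed):

    import numpy as np
    import xarray as xr

    # a small two-group tree; names and data are illustrative only
    tree = xr.DataTree.from_dict(
        {
            "/": xr.Dataset({"root_var": ("x", np.arange(4))}),
            "/child": xr.Dataset({"child_var": ("y", np.arange(6))}),
        }
    )

    chunked = tree.chunk({"x": 2, "y": 3})  # new: chunk every group in the tree
    print(chunked.chunksizes)               # new: per-group mapping of chunk sizes
    in_memory = chunked.compute()           # new: returns a computed copy
    chunked.load()                          # new: computes and loads in place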
6 changes: 4 additions & 2 deletions xarray/core/dataarray.py
@@ -1343,8 +1343,10 @@ def chunks(self) -> tuple[tuple[int, ...], ...] | None:
     @property
     def chunksizes(self) -> Mapping[Any, tuple[int, ...]]:
         """
-        Mapping from dimension names to block lengths for this dataarray's data, or None if
-        the underlying data is not a dask array.
+        Mapping from dimension names to block lengths for this dataarray's data.
+        If this dataarray does not contain chunked arrays, the mapping will be empty.

         Cannot be modified directly, but can be modified by calling .chunk().

         Differs from DataArray.chunks because it returns a mapping of dimensions to chunk shapes
12 changes: 8 additions & 4 deletions xarray/core/dataset.py
@@ -2658,8 +2658,10 @@ def info(self, buf: IO | None = None) -> None:
     @property
     def chunks(self) -> Mapping[Hashable, tuple[int, ...]]:
         """
-        Mapping from dimension names to block lengths for this dataset's data, or None if
-        the underlying data is not a dask array.
+        Mapping from dimension names to block lengths for this dataset's data.
+        If this dataset does not contain chunked arrays, the mapping will be empty.

         Cannot be modified directly, but can be modified by calling .chunk().

         Same as Dataset.chunksizes, but maintained for backwards compatibility.
@@ -2675,8 +2677,10 @@ def chunks(self) -> Mapping[Hashable, tuple[int, ...]]:
     @property
     def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]:
         """
-        Mapping from dimension names to block lengths for this dataset's data, or None if
-        the underlying data is not a dask array.
+        Mapping from dimension names to block lengths for this dataset's data.
+        If this dataset does not contain chunked arrays, the mapping will be empty.

         Cannot be modified directly, but can be modified by calling .chunk().

         Same as Dataset.chunks.
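
The reworded docstrings describe behavior that already holds for Dataset and DataArray: the property returns an empty mapping, never None. A quick sketch (assuming dask is installed):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"a": ("x", np.arange(4))})
    assert not ds.chunksizes                         # unchunked: empty mapping, not None
    assert dict(ds.chunk({"x": 2}).chunksizes) == {"x": (2, 2)}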
212 changes: 208 additions & 4 deletions xarray/core/datatree.py
@@ -18,7 +18,7 @@
 from xarray.core._aggregations import DataTreeAggregations
 from xarray.core._typed_ops import DataTreeOpsMixin
 from xarray.core.alignment import align
-from xarray.core.common import TreeAttrAccessMixin
+from xarray.core.common import TreeAttrAccessMixin, get_chunksizes
 from xarray.core.coordinates import Coordinates, DataTreeCoordinates
 from xarray.core.dataarray import DataArray
 from xarray.core.dataset import Dataset, DataVariables
@@ -49,6 +49,8 @@
     parse_dims_as_set,
 )
 from xarray.core.variable import Variable
+from xarray.namedarray.parallelcompat import get_chunked_array_type
+from xarray.namedarray.pycompat import is_chunked_array

try:
from xarray.core.variable import calculate_dimensions
@@ -68,8 +70,11 @@
     ErrorOptions,
     ErrorOptionsWithWarn,
     NetcdfWriteModes,
+    T_ChunkDimFreq,
+    T_ChunksFreq,
     ZarrWriteModes,
 )
+from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint

# """
# DEVELOPERS' NOTE
@@ -862,9 +867,9 @@ def _copy_node(
     ) -> Self:
         """Copy just one node of a tree."""
         new_node = super()._copy_node(inherit=inherit, deep=deep, memo=memo)
-        data = self._to_dataset_view(rebuild_dims=False, inherit=inherit)
-        if deep:
-            data = data._copy(deep=True, memo=memo)
+        data = self._to_dataset_view(rebuild_dims=False, inherit=inherit)._copy(
+            deep=deep, memo=memo
+        )
         new_node._set_node_data(data)
         return new_node
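
With this change a shallow node copy also shallow-copies the variables (see the "also shallow-copy variables" bullet above): the copy gets fresh Variable objects that share the underlying arrays, so rebinding .data on the copy cannot mutate the original. A minimal illustration of that Variable behavior:

    import numpy as np
    import xarray as xr

    v = xr.Variable("x", np.arange(4))
    shallow = v.copy(deep=False)   # new Variable object, same underlying array
    shallow.data = np.zeros(4)     # rebinds data on the copy only
    assert (v.data == np.arange(4)).all()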

@@ -1896,3 +1901,202 @@ def apply_indexers(dataset, node_indexers):

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
return self._selective_indexing(apply_indexers, indexers)

def load(self, **kwargs) -> Self:
"""Manually trigger loading and/or computation of this datatree's data
from disk or a remote source into memory and return this datatree.
Unlike compute, the original datatree is modified and returned.
Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.
Parameters
----------
**kwargs : dict
Additional keyword arguments passed on to ``dask.compute``.
See Also
--------
Dataset.load
dask.compute
"""
# access .data to coerce everything to numpy or dask arrays
lazy_data = {
path: {
k: v._data
for k, v in node.variables.items()
if is_chunked_array(v._data)
}
for path, node in self.subtree_with_keys
}
flat_lazy_data = {
(path, var_name): array
for path, node in lazy_data.items()
for var_name, array in node.items()
}
if flat_lazy_data:
chunkmanager = get_chunked_array_type(*flat_lazy_data.values())

# evaluate all the chunked arrays simultaneously
evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
*flat_lazy_data.values(), **kwargs
)

for (path, var_name), data in zip(
flat_lazy_data, evaluated_data, strict=False
):
self[path].variables[var_name].data = data

        # load everything else sequentially
        for path, node in self.subtree_with_keys:
            for k, v in node.variables.items():
                # anything not computed in bulk above gets loaded on its own
                if (path, k) not in flat_lazy_data:
                    v.load()

return self
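
The point of flattening the tree into (path, variable-name) keys is that the chunked arrays from every group go through a single compute call, so the scheduler can evaluate tasks shared between graphs once. The same effect, sketched with plain dask (illustrative only):

    import dask
    import dask.array as da

    x = da.ones((4, 4), chunks=2)
    y = x + 1                        # y's graph contains all of x's tasks

    x_np, y_np = dask.compute(x, y)  # one call: shared tasks run once

    x_np = x.compute()               # computed separately instead, ...
    y_np = y.compute()               # ... x's tasks would run a second time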

def compute(self, **kwargs) -> Self:
"""Manually trigger loading and/or computation of this datatree's data
from disk or a remote source into memory and return a new datatree.
Unlike load, the original datatree is left unaltered.
Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.
Parameters
----------
**kwargs : dict
Additional keyword arguments passed on to ``dask.compute``.
Returns
-------
object : DataTree
New object with lazy data variables and coordinates as in-memory arrays.
See Also
--------
dask.compute
"""
new = self.copy(deep=False)
return new.load(**kwargs)
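
Because compute() shallow-copies the tree (including, per _copy_node above, its variables) before calling load(), the original keeps its lazy arrays. A usage sketch, reusing the illustrative `tree` from the top of this page:

    chunked = tree.chunk({"x": 2})
    computed = chunked.compute()                           # new object
    assert dict(computed.chunksizes["/"]) == {}            # copy is in memory
    assert dict(chunked.chunksizes["/"]) == {"x": (2, 2)}  # original still lazy

    chunked.load()                                         # in place
    assert dict(chunked.chunksizes["/"]) == {}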

@property
def chunksizes(self) -> Mapping[str, Mapping[Hashable, tuple[int, ...]]]:
"""
Mapping from group paths to a mapping of chunksizes.
If there's no chunked data in a group, the corresponding mapping of chunksizes will be empty.
Cannot be modified directly, but can be modified by calling .chunk().
See Also
--------
DataTree.chunk
Dataset.chunksizes
"""
return Frozen(
{
node.path: get_chunksizes(node.variables.values())
for node in self.subtree
}
)
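
Unlike Dataset.chunksizes, this returns one mapping per group, keyed by the group's absolute path. With the illustrative `tree` from above, the result looks roughly like:

    print(tree.chunk({"x": 2, "y": 3}).chunksizes)
    # Frozen({'/': Frozen({'x': (2, 2)}), '/child': Frozen({'y': (3, 3)})})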

def chunk(
self,
chunks: T_ChunksFreq = {}, # noqa: B006 # {} even though it's technically unsafe, is being used intentionally here (#4667)
name_prefix: str = "xarray-",
token: str | None = None,
lock: bool = False,
inline_array: bool = False,
chunked_array_type: str | ChunkManagerEntrypoint | None = None,
from_array_kwargs=None,
**chunks_kwargs: T_ChunkDimFreq,
) -> Self:
"""Coerce all arrays in all groups in this tree into dask arrays with the given
chunks.
Non-dask arrays in this tree will be converted to dask arrays. Dask
arrays will be rechunked to the given chunk sizes.
If neither chunks is not provided for one or more dimensions, chunk
sizes along that dimension will not be updated; non-dask arrays will be
converted into dask arrays with a single block.
Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` object is also accepted.
Parameters
----------
chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional
Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or
``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}``.
name_prefix : str, default: "xarray-"
Prefix for the name of any new dask arrays.
token : str, optional
Token uniquely identifying this datatree.
lock : bool, default: False
Passed on to :py:func:`dask.array.from_array`, if the array is not
already as dask array.
inline_array: bool, default: False
Passed on to :py:func:`dask.array.from_array`, if the array is not
already as dask array.
chunked_array_type: str, optional
Which chunked array type to coerce this datatree's arrays to.
Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntryPoint` system.
Experimental API that should not be relied upon.
from_array_kwargs: dict, optional
Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create
chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg.
For example, with dask as the default chunked array type, this method would pass additional kwargs
to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
**chunks_kwargs : {dim: chunks, ...}, optional
The keyword arguments form of ``chunks``.
One of chunks or chunks_kwargs must be provided
Returns
-------
chunked : xarray.DataTree
See Also
--------
Dataset.chunk
Dataset.chunksizes
xarray.unify_chunks
dask.array.from_array
"""
# don't support deprecated ways of passing chunks
if not isinstance(chunks, Mapping):
raise TypeError(
f"invalid type for chunks: {type(chunks)}. Only mappings are supported."
)
combined_chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")

all_dims = self._get_all_dims()

bad_dims = combined_chunks.keys() - all_dims
if bad_dims:
raise ValueError(
f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(all_dims)}"
)

rechunked_groups = {
path: node.dataset.chunk(
{
dim: size
for dim, size in combined_chunks.items()
if dim in node._node_dims
},
name_prefix=name_prefix,
token=token,
lock=lock,
inline_array=inline_array,
chunked_array_type=chunked_array_type,
from_array_kwargs=from_array_kwargs,
)
for path, node in self.subtree_with_keys
}

return self.from_dict(rechunked_groups, name=self.name)
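
Per the "raise on unknown chunk dim" bullet, dimensions that appear nowhere in the tree are rejected up front rather than silently ignored; with the illustrative `tree` from above:

    tree.chunk({"z": 2})
    # raises ValueError: chunks keys ('z',) not found in data dimensions
    # ('x', 'y')  -- exact dimension ordering in the message may vary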
6 changes: 4 additions & 2 deletions xarray/namedarray/core.py
@@ -725,8 +725,10 @@ def chunksizes(
         self,
     ) -> Mapping[_Dim, _Shape]:
         """
-        Mapping from dimension names to block lengths for this namedArray's data, or None if
-        the underlying data is not a dask array.
+        Mapping from dimension names to block lengths for this NamedArray's data.
+        If this NamedArray does not contain chunked arrays, the mapping will be empty.

         Cannot be modified directly, but can be modified by calling .chunk().

         Differs from NamedArray.chunks because it returns a mapping of dimensions to chunk shapes