From 1f5889a924245717a09398ab5855bbcac12ef71c Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Thu, 24 Oct 2024 21:11:28 +0200 Subject: [PATCH] implement `dask` methods on `DataTree` (#9670) * implement `compute` and `load` * also shallow-copy variables * implement `chunksizes` * add tests for `load` * add tests for `chunksizes` * improve the `load` tests using `DataTree.chunksizes` * add a test for `compute` * un-xfail a xpassing test * implement and test `DataTree.chunk` * link to `Dataset.load` Co-authored-by: Tom Nicholas * use `tree.subtree` to get absolute paths * filter out missing dims before delegating to `Dataset.chunk` * fix the type hints for `DataTree.chunksizes` * try using `self.from_dict` instead * type-hint intermediate test variables * use `_node_dims` instead * raise on unknown chunk dim * check that errors in `chunk` are raised properly * adapt the docstrings of the new methods * allow computing / loading unchunked trees * reword the `chunksizes` properties * also freeze the top-level chunk sizes * also reword `DataArray.chunksizes` * fix a copy-paste error * same for `NamedArray.chunksizes` --------- Co-authored-by: Tom Nicholas --- xarray/core/dataarray.py | 6 +- xarray/core/dataset.py | 12 +- xarray/core/datatree.py | 212 +++++++++++++++++++++++++++++++++- xarray/namedarray/core.py | 6 +- xarray/tests/test_datatree.py | 120 ++++++++++++++++++- 5 files changed, 342 insertions(+), 14 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 826acc7c7e9..db7824d8c90 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1343,8 +1343,10 @@ def chunks(self) -> tuple[tuple[int, ...], ...] | None: @property def chunksizes(self) -> Mapping[Any, tuple[int, ...]]: """ - Mapping from dimension names to block lengths for this dataarray's data, or None if - the underlying data is not a dask array. + Mapping from dimension names to block lengths for this dataarray's data. 
+ + If this dataarray does not contain chunked arrays, the mapping will be empty. + Cannot be modified directly, but can be modified by calling .chunk(). Differs from DataArray.chunks because it returns a mapping of dimensions to chunk shapes diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ab3e2901dcf..66ceea17b91 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2658,8 +2658,10 @@ def info(self, buf: IO | None = None) -> None: @property def chunks(self) -> Mapping[Hashable, tuple[int, ...]]: """ - Mapping from dimension names to block lengths for this dataset's data, or None if - the underlying data is not a dask array. + Mapping from dimension names to block lengths for this dataset's data. + + If this dataset does not contain chunked arrays, the mapping will be empty. + Cannot be modified directly, but can be modified by calling .chunk(). Same as Dataset.chunksizes, but maintained for backwards compatibility. @@ -2675,8 +2677,10 @@ def chunks(self) -> Mapping[Hashable, tuple[int, ...]]: @property def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]: """ - Mapping from dimension names to block lengths for this dataset's data, or None if - the underlying data is not a dask array. + Mapping from dimension names to block lengths for this dataset's data. + + If this dataset does not contain chunked arrays, the mapping will be empty. + Cannot be modified directly, but can be modified by calling .chunk(). Same as Dataset.chunks. 
diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index e9e30da5f05..d4ee2621557 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -18,7 +18,7 @@ from xarray.core._aggregations import DataTreeAggregations from xarray.core._typed_ops import DataTreeOpsMixin from xarray.core.alignment import align -from xarray.core.common import TreeAttrAccessMixin +from xarray.core.common import TreeAttrAccessMixin, get_chunksizes from xarray.core.coordinates import Coordinates, DataTreeCoordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset, DataVariables @@ -49,6 +49,8 @@ parse_dims_as_set, ) from xarray.core.variable import Variable +from xarray.namedarray.parallelcompat import get_chunked_array_type +from xarray.namedarray.pycompat import is_chunked_array try: from xarray.core.variable import calculate_dimensions @@ -68,8 +70,11 @@ ErrorOptions, ErrorOptionsWithWarn, NetcdfWriteModes, + T_ChunkDimFreq, + T_ChunksFreq, ZarrWriteModes, ) + from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint # """ # DEVELOPERS' NOTE @@ -862,9 +867,9 @@ def _copy_node( ) -> Self: """Copy just one node of a tree.""" new_node = super()._copy_node(inherit=inherit, deep=deep, memo=memo) - data = self._to_dataset_view(rebuild_dims=False, inherit=inherit) - if deep: - data = data._copy(deep=True, memo=memo) + data = self._to_dataset_view(rebuild_dims=False, inherit=inherit)._copy( + deep=deep, memo=memo + ) new_node._set_node_data(data) return new_node @@ -1896,3 +1901,202 @@ def apply_indexers(dataset, node_indexers): indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel") return self._selective_indexing(apply_indexers, indexers) + + def load(self, **kwargs) -> Self: + """Manually trigger loading and/or computation of this datatree's data + from disk or a remote source into memory and return this datatree. + Unlike compute, the original datatree is modified and returned. 
+ + Normally, it should not be necessary to call this method in user code, + because all xarray functions should either work on deferred data or + load data automatically. However, this method can be necessary when + working with many file objects on disk. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.compute``. + + See Also + -------- + Dataset.load + dask.compute + """ + # access .data to coerce everything to numpy or dask arrays + lazy_data = { + path: { + k: v._data + for k, v in node.variables.items() + if is_chunked_array(v._data) + } + for path, node in self.subtree_with_keys + } + flat_lazy_data = { + (path, var_name): array + for path, node in lazy_data.items() + for var_name, array in node.items() + } + if flat_lazy_data: + chunkmanager = get_chunked_array_type(*flat_lazy_data.values()) + + # evaluate all the chunked arrays simultaneously + evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute( + *flat_lazy_data.values(), **kwargs + ) + + for (path, var_name), data in zip( + flat_lazy_data, evaluated_data, strict=False + ): + self[path].variables[var_name].data = data + + # load everything else sequentially + for node in self.subtree: + for k, v in node.variables.items(): + if k not in lazy_data: + v.load() + + return self + + def compute(self, **kwargs) -> Self: + """Manually trigger loading and/or computation of this datatree's data + from disk or a remote source into memory and return a new datatree. + Unlike load, the original datatree is left unaltered. + + Normally, it should not be necessary to call this method in user code, + because all xarray functions should either work on deferred data or + load data automatically. However, this method can be necessary when + working with many file objects on disk. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments passed on to ``dask.compute``. 
+ + Returns + ------- + object : DataTree + New object with lazy data variables and coordinates as in-memory arrays. + + See Also + -------- + dask.compute + """ + new = self.copy(deep=False) + return new.load(**kwargs) + + @property + def chunksizes(self) -> Mapping[str, Mapping[Hashable, tuple[int, ...]]]: + """ + Mapping from group paths to a mapping of chunksizes. + + If there's no chunked data in a group, the corresponding mapping of chunksizes will be empty. + + Cannot be modified directly, but can be modified by calling .chunk(). + + See Also + -------- + DataTree.chunk + Dataset.chunksizes + """ + return Frozen( + { + node.path: get_chunksizes(node.variables.values()) + for node in self.subtree + } + ) + + def chunk( + self, + chunks: T_ChunksFreq = {}, # noqa: B006 # {} even though it's technically unsafe, is being used intentionally here (#4667) + name_prefix: str = "xarray-", + token: str | None = None, + lock: bool = False, + inline_array: bool = False, + chunked_array_type: str | ChunkManagerEntrypoint | None = None, + from_array_kwargs=None, + **chunks_kwargs: T_ChunkDimFreq, + ) -> Self: + """Coerce all arrays in all groups in this tree into dask arrays with the given + chunks. + + Non-dask arrays in this tree will be converted to dask arrays. Dask + arrays will be rechunked to the given chunk sizes. + + If chunks is not provided for one or more dimensions, chunk + sizes along that dimension will not be updated; non-dask arrays will be + converted into dask arrays with a single block. + + Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` object is also accepted. + + Parameters + ---------- + chunks : mapping of hashable to int or a TimeResampler, optional + Chunk sizes along each dimension; only mappings are supported, e.g., + ``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}``. + name_prefix : str, default: "xarray-" + Prefix for the name of any new dask arrays. 
+ token : str, optional + Token uniquely identifying this datatree. + lock : bool, default: False + Passed on to :py:func:`dask.array.from_array`, if the array is not + already a dask array. + inline_array : bool, default: False + Passed on to :py:func:`dask.array.from_array`, if the array is not + already a dask array. + chunked_array_type: str, optional + Which chunked array type to coerce this datatree's arrays to. + Defaults to 'dask' if installed, else whatever is registered via the `ChunkManagerEntrypoint` system. + Experimental API that should not be relied upon. + from_array_kwargs: dict, optional + Additional keyword arguments passed on to the `ChunkManagerEntrypoint.from_array` method used to create + chunked arrays, via whichever chunk manager is specified through the `chunked_array_type` kwarg. + For example, with dask as the default chunked array type, this method would pass additional kwargs + to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + **chunks_kwargs : {dim: chunks, ...}, optional + The keyword arguments form of ``chunks``. + One of chunks or chunks_kwargs must be provided. + + Returns + ------- + chunked : xarray.DataTree + + See Also + -------- + Dataset.chunk + Dataset.chunksizes + xarray.unify_chunks + dask.array.from_array + """ + # don't support deprecated ways of passing chunks + if not isinstance(chunks, Mapping): + raise TypeError( + f"invalid type for chunks: {type(chunks)}. Only mappings are supported." 
+ ) + combined_chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") + + all_dims = self._get_all_dims() + + bad_dims = combined_chunks.keys() - all_dims + if bad_dims: + raise ValueError( + f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(all_dims)}" + ) + + rechunked_groups = { + path: node.dataset.chunk( + { + dim: size + for dim, size in combined_chunks.items() + if dim in node._node_dims + }, + name_prefix=name_prefix, + token=token, + lock=lock, + inline_array=inline_array, + chunked_array_type=chunked_array_type, + from_array_kwargs=from_array_kwargs, + ) + for path, node in self.subtree_with_keys + } + + return self.from_dict(rechunked_groups, name=self.name) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 6f5ed671de8..b753d26d622 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -725,8 +725,10 @@ def chunksizes( self, ) -> Mapping[_Dim, _Shape]: """ - Mapping from dimension names to block lengths for this namedArray's data, or None if - the underlying data is not a dask array. + Mapping from dimension names to block lengths for this NamedArray's data. + + If this NamedArray does not contain chunked arrays, the mapping will be empty. + Cannot be modified directly, but can be modified by calling .chunk(). 
Differs from NamedArray.chunks because it returns a mapping of dimensions to chunk shapes diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 3be3fbd620d..1fa93d9853d 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1,6 +1,7 @@ import re import sys import typing +from collections.abc import Mapping from copy import copy, deepcopy from textwrap import dedent @@ -13,7 +14,12 @@ from xarray.core.datatree import DataTree from xarray.core.treenode import NotFoundInTreeError from xarray.testing import assert_equal, assert_identical -from xarray.tests import assert_array_equal, create_test_data, source_ndarray +from xarray.tests import ( + assert_array_equal, + create_test_data, + requires_dask, + source_ndarray, +) ON_WINDOWS = sys.platform == "win32" @@ -858,7 +864,6 @@ def test_to_dict(self): actual = DataTree.from_dict(tree.children["a"].to_dict(relative=True)) assert_identical(expected, actual) - @pytest.mark.xfail def test_roundtrip_unnamed_root(self, simple_datatree) -> None: # See GH81 @@ -2195,3 +2200,114 @@ def test_close_dataset(self, tree_and_closers): # with tree: # pass + + +@requires_dask +class TestDask: + def test_chunksizes(self): + ds1 = xr.Dataset({"a": ("x", np.arange(10))}) + ds2 = xr.Dataset({"b": ("y", np.arange(5))}) + ds3 = xr.Dataset({"c": ("z", np.arange(4))}) + ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))}) + + groups = { + "/": ds1.chunk({"x": 5}), + "/group1": ds2.chunk({"y": 3}), + "/group2": ds3.chunk({"z": 2}), + "/group1/subgroup1": ds4.chunk({"x": 5}), + } + + tree = xr.DataTree.from_dict(groups) + + expected_chunksizes = {path: node.chunksizes for path, node in groups.items()} + + assert tree.chunksizes == expected_chunksizes + + def test_load(self): + ds1 = xr.Dataset({"a": ("x", np.arange(10))}) + ds2 = xr.Dataset({"b": ("y", np.arange(5))}) + ds3 = xr.Dataset({"c": ("z", np.arange(4))}) + ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))}) + + groups = {"/": ds1, 
"/group1": ds2, "/group2": ds3, "/group1/subgroup1": ds4} + + expected = xr.DataTree.from_dict(groups) + tree = xr.DataTree.from_dict( + { + "/": ds1.chunk({"x": 5}), + "/group1": ds2.chunk({"y": 3}), + "/group2": ds3.chunk({"z": 2}), + "/group1/subgroup1": ds4.chunk({"x": 5}), + } + ) + expected_chunksizes: Mapping[str, Mapping] + expected_chunksizes = {node.path: {} for node in tree.subtree} + actual = tree.load() + + assert_identical(actual, expected) + assert tree.chunksizes == expected_chunksizes + assert actual.chunksizes == expected_chunksizes + + tree = xr.DataTree.from_dict(groups) + actual = tree.load() + assert_identical(actual, expected) + assert actual.chunksizes == expected_chunksizes + + def test_compute(self): + ds1 = xr.Dataset({"a": ("x", np.arange(10))}) + ds2 = xr.Dataset({"b": ("y", np.arange(5))}) + ds3 = xr.Dataset({"c": ("z", np.arange(4))}) + ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))}) + + expected = xr.DataTree.from_dict( + {"/": ds1, "/group1": ds2, "/group2": ds3, "/group1/subgroup1": ds4} + ) + tree = xr.DataTree.from_dict( + { + "/": ds1.chunk({"x": 5}), + "/group1": ds2.chunk({"y": 3}), + "/group2": ds3.chunk({"z": 2}), + "/group1/subgroup1": ds4.chunk({"x": 5}), + } + ) + original_chunksizes = tree.chunksizes + expected_chunksizes: Mapping[str, Mapping] + expected_chunksizes = {node.path: {} for node in tree.subtree} + actual = tree.compute() + + assert_identical(actual, expected) + + assert actual.chunksizes == expected_chunksizes, "mismatching chunksizes" + assert tree.chunksizes == original_chunksizes, "original tree was modified" + + def test_chunk(self): + ds1 = xr.Dataset({"a": ("x", np.arange(10))}) + ds2 = xr.Dataset({"b": ("y", np.arange(5))}) + ds3 = xr.Dataset({"c": ("z", np.arange(4))}) + ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))}) + + expected = xr.DataTree.from_dict( + { + "/": ds1.chunk({"x": 5}), + "/group1": ds2.chunk({"y": 3}), + "/group2": ds3.chunk({"z": 2}), + "/group1/subgroup1": ds4.chunk({"x": 5}), + 
} + ) + + tree = xr.DataTree.from_dict( + {"/": ds1, "/group1": ds2, "/group2": ds3, "/group1/subgroup1": ds4} + ) + actual = tree.chunk({"x": 5, "y": 3, "z": 2}) + + assert_identical(actual, expected) + assert actual.chunksizes == expected.chunksizes + + with pytest.raises(TypeError, match="invalid type"): + tree.chunk(None) + + with pytest.raises(TypeError, match="invalid type"): + tree.chunk((1, 2)) + + with pytest.raises(ValueError, match="not found in data dimensions"): + tree.chunk({"u": 2})