Commit e58edcc

Re-implement map_over_datasets using group_subtrees (#9636)
* Add zip_subtrees for paired iteration over DataTrees

  This should be used for implementing DataTree arithmetic inside
  map_over_datasets, so the result does not depend on the order in which
  child nodes are defined.

  I have also added a minimal implementation of breadth-first-search with an
  explicit queue in place of the current recursion based solution in
  xarray.core.iterators (which has been removed). The new implementation is
  also slightly faster in my microbenchmark:

      In [1]: import xarray as xr

      In [2]: tree = xr.DataTree.from_dict({f"/x{i}": None for i in range(100)})

      In [3]: %timeit _ = list(tree.subtree)

      # on main
      87.2 μs ± 394 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

      # with this branch
      55.1 μs ± 294 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

* fix pytype error

* Re-implement map_over_datasets

  The main changes:

  - It is implemented using zip_subtrees, which means it should properly
    handle DataTrees where the nodes are defined in a different order.
  - For simplicity, I removed handling of `**kwargs`, in order to preserve
    some flexibility for adding keyword arguments.
  - I removed automatic skipping of empty nodes, because there are almost
    assuredly cases where that would make sense. This could be restored with
    an optional keyword argument.

* fix typing of map_over_datasets
* add group_subtrees
* wip fixes
* update isomorphic
* documentation and API change for map_over_datasets
* mypy fixes
* fix test
* diff formatting
* more mypy
* doc fix
* more doc fix
* add api docs
* add utility for joining path on windows
* docstring
* add an overload for two return values from map_over_datasets
* partial fixes per review
* fixes per review
* remove a couple of xfails
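To illustrate the breadth-first search with an explicit queue that the commit message mentions, here is a minimal standalone sketch. It is not the code from this commit: the `Node` class and the `iter_subtree` helper are hypothetical stand-ins, and the tree is a plain dictionary of children rather than a `DataTree`.

```python
from collections import deque
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for a tree node; not xarray's own node class."""

    name: str
    children: dict[str, "Node"] = field(default_factory=dict)


def iter_subtree(root: Node):
    """Yield the root and all of its descendants in breadth-first order."""
    queue = deque([root])  # explicit FIFO queue instead of recursion
    while queue:
        node = queue.popleft()
        yield node
        # Children are visited only after every node already in the queue.
        queue.extend(node.children.values())


tree = Node("root", {"a": Node("a", {"c": Node("c")}), "b": Node("b")})
names = [node.name for node in iter_subtree(tree)]
```

Because the queue is first-in, first-out, both children of the root are yielded before the grandchild `c`.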
1 parent 8f6e45b commit e58edcc

19 files changed: +627 -1069 lines changed

DATATREE_MIGRATION_GUIDE.md

Lines changed: 3 additions & 1 deletion
@@ -7,7 +7,7 @@ This guide is for previous users of the prototype `datatree.DataTree` class in t
 > [!IMPORTANT]
 > There are breaking changes! You should not expect that code written with `xarray-contrib/datatree` will work without any modifications. At the absolute minimum you will need to change the top-level import statement, but there are other changes too.
 
-We have made various changes compared to the prototype version. These can be split into three categories: data model changes, which affect the hierarchal structure itself, integration with xarray's IO backends; and minor API changes, which mostly consist of renaming methods to be more self-consistent.
+We have made various changes compared to the prototype version. These can be split into three categories: data model changes, which affect the hierarchal structure itself; integration with xarray's IO backends; and minor API changes, which mostly consist of renaming methods to be more self-consistent.
 
 ### Data model changes
 
@@ -17,6 +17,8 @@ These alignment checks happen at tree construction time, meaning there are some
 
 The alignment checks allowed us to add "Coordinate Inheritance", a much-requested feature where indexed coordinate variables are now "inherited" down to child nodes. This allows you to define common coordinates in a parent group that are then automatically available on every child node. The distinction between a locally-defined coordinate variable and an inherited coordinate that was defined on a parent node is reflected in the `DataTree.__repr__`. Generally if you prefer not to have these variables be inherited you can get more similar behaviour to the old `datatree` package by removing indexes from coordinates, as this prevents inheritance.
 
+Tree structure checks between multiple trees (i.e., `DataTree.isomorphic`) and pairing of nodes in arithmetic have also changed. Nodes are now matched (with `xarray.group_subtrees`) based on their relative paths, without regard to the order in which child nodes are defined.
+
 For further documentation see the page in the user guide on Hierarchical Data.
 
 ### Integrated backends
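The path-based matching described in the migration guide change above can be illustrated with a small standalone sketch. It is only an analogy: trees are flattened into plain `{relative_path: value}` dictionaries, and `group_by_path` and `TreeMismatchError` are hypothetical names invented for this example, whereas the real `xarray.group_subtrees` operates on `DataTree` objects and raises `TreeIsomorphismError`.

```python
class TreeMismatchError(ValueError):
    """Hypothetical error for this sketch; the commit uses TreeIsomorphismError."""


def group_by_path(*trees: dict[str, object]):
    """Yield (path, values) pairs for trees given as {relative_path: value} dicts.

    Nodes are paired by path, so the order in which each tree defines its
    nodes does not matter. Trees with different sets of paths are rejected.
    """
    first, *rest = trees
    for other in rest:
        if set(other) != set(first):
            raise TreeMismatchError(
                f"paths differ: {sorted(first)} vs {sorted(other)}"
            )
    for path in first:
        yield path, tuple(tree[path] for tree in trees)


tree1 = {".": 0, "a": 1, "b": 2}
tree2 = {".": 0, "b": 20, "a": 10}  # same paths, defined in a different order
summed = {path: x + y for path, (x, y) in group_by_path(tree1, tree2)}
```

Here `a` is paired with `a` and `b` with `b` even though the second dictionary lists them in the opposite order, which is the property the new matching rule is meant to guarantee.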

doc/api.rst

Lines changed: 11 additions & 1 deletion
@@ -749,6 +749,17 @@ Manipulate the contents of a single ``DataTree`` node.
    DataTree.assign
    DataTree.drop_nodes
 
+DataTree Operations
+-------------------
+
+Apply operations over multiple ``DataTree`` objects.
+
+.. autosummary::
+   :toctree: generated/
+
+   map_over_datasets
+   group_subtrees
+
 Comparisons
 -----------
 
@@ -954,7 +965,6 @@ DataTree methods
 
    open_datatree
    open_groups
-   map_over_datasets
    DataTree.to_dict
    DataTree.to_netcdf
    DataTree.to_zarr

doc/user-guide/hierarchical-data.rst

Lines changed: 63 additions & 18 deletions
@@ -362,21 +362,26 @@ This returns an iterable of nodes, which yields them in depth-first order.
     for node in vertebrates.subtree:
         print(node.path)
 
-A very useful pattern is to use :py:class:`~xarray.DataTree.subtree` conjunction with the :py:class:`~xarray.DataTree.path` property to manipulate the nodes however you wish,
-then rebuild a new tree using :py:meth:`xarray.DataTree.from_dict()`.
+Similarly, :py:class:`~xarray.DataTree.subtree_with_keys` returns an iterable of
+relative paths and corresponding nodes.
 
+A very useful pattern is to iterate over :py:class:`~xarray.DataTree.subtree_with_keys`
+to manipulate nodes however you wish, then rebuild a new tree using
+:py:meth:`xarray.DataTree.from_dict()`.
 For example, we could keep only the nodes containing data by looping over all nodes,
 checking if they contain any data using :py:class:`~xarray.DataTree.has_data`,
 then rebuilding a new tree using only the paths of those nodes:
 
 .. ipython:: python
 
-    non_empty_nodes = {node.path: node.dataset for node in dt.subtree if node.has_data}
+    non_empty_nodes = {
+        path: node.dataset for path, node in dt.subtree_with_keys if node.has_data
+    }
     xr.DataTree.from_dict(non_empty_nodes)
 
 You can see this tree is similar to the ``dt`` object above, except that it is missing the empty nodes ``a/c`` and ``a/c/d``.
 
-(If you want to keep the name of the root node, you will need to add the ``name`` kwarg to :py:class:`~xarray.DataTree.from_dict`, i.e. ``DataTree.from_dict(non_empty_nodes, name=dt.root.name)``.)
+(If you want to keep the name of the root node, you will need to add the ``name`` kwarg to :py:class:`~xarray.DataTree.from_dict`, i.e. ``DataTree.from_dict(non_empty_nodes, name=dt.name)``.)
 
 .. _manipulating trees:
 
@@ -573,38 +578,78 @@ Then calculate the RMS value of these signals:
 
 .. _multiple trees:
 
-We can also use the :py:meth:`~xarray.map_over_datasets` decorator to promote a function which accepts datasets into one which
-accepts datatrees.
+We can also use :py:func:`~xarray.map_over_datasets` to apply a function over
+the data in multiple trees, by passing the trees as positional arguments.
 
 Operating on Multiple Trees
 ---------------------------
 
 The examples so far have involved mapping functions or methods over the nodes of a single tree,
 but we can generalize this to mapping functions over multiple trees at once.
 
+Iterating Over Multiple Trees
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To iterate over the corresponding nodes in multiple trees, use
+:py:func:`~xarray.group_subtrees` instead of
+:py:class:`~xarray.DataTree.subtree_with_keys`. This combines well with
+:py:meth:`xarray.DataTree.from_dict()` to build a new tree:
+
+.. ipython:: python
+
+    dt1 = xr.DataTree.from_dict({"a": xr.Dataset({"x": 1}), "b": xr.Dataset({"x": 2})})
+    dt2 = xr.DataTree.from_dict(
+        {"a": xr.Dataset({"x": 10}), "b": xr.Dataset({"x": 20})}
+    )
+    result = {}
+    for path, (node1, node2) in xr.group_subtrees(dt1, dt2):
+        result[path] = node1.dataset + node2.dataset
+    xr.DataTree.from_dict(result)
+
+Alternatively, you can apply a function directly to paired datasets at every node
+using :py:func:`xarray.map_over_datasets`:
+
+.. ipython:: python
+
+    xr.map_over_datasets(lambda x, y: x + y, dt1, dt2)
+
 Comparing Trees for Isomorphism
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 For it to make sense to map a single non-unary function over the nodes of multiple trees at once,
-each tree needs to have the same structure. Specifically two trees can only be considered similar, or "isomorphic",
-if they have the same number of nodes, and each corresponding node has the same number of children.
-We can check if any two trees are isomorphic using the :py:meth:`~xarray.DataTree.isomorphic` method.
+each tree needs to have the same structure. Specifically two trees can only be considered similar,
+or "isomorphic", if the full paths to all of their descendant nodes are the same.
+
+Applying :py:func:`~xarray.group_subtrees` to trees with different structures
+raises :py:class:`~xarray.TreeIsomorphismError`:
 
 .. ipython:: python
     :okexcept:
 
-    dt1 = xr.DataTree.from_dict({"a": None, "a/b": None})
-    dt2 = xr.DataTree.from_dict({"a": None})
-    dt1.isomorphic(dt2)
+    tree = xr.DataTree.from_dict({"a": None, "a/b": None, "a/c": None})
+    simple_tree = xr.DataTree.from_dict({"a": None})
+    for _ in xr.group_subtrees(tree, simple_tree):
+        ...
+
+We can also explicitly check if any two trees are isomorphic using the :py:meth:`~xarray.DataTree.isomorphic` method:
+
+.. ipython:: python
+
+    tree.isomorphic(simple_tree)
 
-    dt3 = xr.DataTree.from_dict({"a": None, "b": None})
-    dt1.isomorphic(dt3)
+Corresponding tree nodes do not need to have the same data in order to be considered isomorphic:
 
-    dt4 = xr.DataTree.from_dict({"A": None, "A/B": xr.Dataset({"foo": 1})})
-    dt1.isomorphic(dt4)
+.. ipython:: python
+
+    tree_with_data = xr.DataTree.from_dict({"a": xr.Dataset({"foo": 1})})
+    simple_tree.isomorphic(tree_with_data)
+
+They also do not need to define child nodes in the same order:
+
+.. ipython:: python
 
-If the trees are not isomorphic a :py:class:`~xarray.TreeIsomorphismError` will be raised.
-Notice that corresponding tree nodes do not need to have the same name or contain the same data in order to be considered isomorphic.
+    reordered_tree = xr.DataTree.from_dict({"a": None, "a/c": None, "a/b": None})
+    tree.isomorphic(reordered_tree)
 
 Arithmetic Between Multiple Trees
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

doc/whats-new.rst

Lines changed: 2 additions & 2 deletions
@@ -23,8 +23,8 @@ New Features
 ~~~~~~~~~~~~
 - ``DataTree`` related functionality is now exposed in the main ``xarray`` public
   API. This includes: ``xarray.DataTree``, ``xarray.open_datatree``, ``xarray.open_groups``,
-  ``xarray.map_over_datasets``, ``xarray.register_datatree_accessor`` and
-  ``xarray.testing.assert_isomorphic``.
+  ``xarray.map_over_datasets``, ``xarray.group_subtrees``,
+  ``xarray.register_datatree_accessor`` and ``xarray.testing.assert_isomorphic``.
   By `Owen Littlejohns <https://github.com/owenlittlejohns>`_,
   `Eni Awowale <https://github.com/eni-awowale>`_,
   `Matt Savoie <https://github.com/flamingbear>`_,

xarray/__init__.py

Lines changed: 8 additions & 2 deletions
@@ -34,7 +34,7 @@
 from xarray.core.dataarray import DataArray
 from xarray.core.dataset import Dataset
 from xarray.core.datatree import DataTree
-from xarray.core.datatree_mapping import TreeIsomorphismError, map_over_datasets
+from xarray.core.datatree_mapping import map_over_datasets
 from xarray.core.extensions import (
     register_dataarray_accessor,
     register_dataset_accessor,
@@ -45,7 +45,12 @@
 from xarray.core.merge import Context, MergeError, merge
 from xarray.core.options import get_options, set_options
 from xarray.core.parallel import map_blocks
-from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError
+from xarray.core.treenode import (
+    InvalidTreeError,
+    NotFoundInTreeError,
+    TreeIsomorphismError,
+    group_subtrees,
+)
 from xarray.core.variable import IndexVariable, Variable, as_variable
 from xarray.namedarray.core import NamedArray
 from xarray.util.print_versions import show_versions
@@ -82,6 +87,7 @@
     "cross",
     "full_like",
     "get_options",
+    "group_subtrees",
     "infer_freq",
     "load_dataarray",
     "load_dataset",

xarray/core/computation.py

Lines changed: 1 addition & 14 deletions
@@ -31,7 +31,7 @@
 from xarray.core.merge import merge_attrs, merge_coordinates_without_align
 from xarray.core.options import OPTIONS, _get_keep_attrs
 from xarray.core.types import Dims, T_DataArray
-from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set
+from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set, result_name
 from xarray.core.variable import Variable
 from xarray.namedarray.parallelcompat import get_chunked_array_type
 from xarray.namedarray.pycompat import is_chunked_array
@@ -46,7 +46,6 @@
 MissingCoreDimOptions = Literal["raise", "copy", "drop"]
 
 _NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
-_DEFAULT_NAME = utils.ReprObject("<default-name>")
 _JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})
 
 
@@ -186,18 +185,6 @@ def _enumerate(dim):
     return str(alt_signature)
 
 
-def result_name(objects: Iterable[Any]) -> Any:
-    # use the same naming heuristics as pandas:
-    # https://github.com/blaze/blaze/issues/458#issuecomment-51936356
-    names = {getattr(obj, "name", _DEFAULT_NAME) for obj in objects}
-    names.discard(_DEFAULT_NAME)
-    if len(names) == 1:
-        (name,) = names
-    else:
-        name = None
-    return name
-
-
 def _get_coords_list(args: Iterable[Any]) -> list[Coordinates]:
     coords_list = []
     for arg in args:

xarray/core/dataarray.py

Lines changed: 2 additions & 10 deletions
@@ -75,6 +75,7 @@
     either_dict_or_kwargs,
     hashable,
     infix_dims,
+    result_name,
 )
 from xarray.core.variable import (
     IndexVariable,
@@ -4726,15 +4727,6 @@ def identical(self, other: Self) -> bool:
         except (TypeError, AttributeError):
             return False
 
-    def _result_name(self, other: Any = None) -> Hashable | None:
-        # use the same naming heuristics as pandas:
-        # https://github.com/ContinuumIO/blaze/issues/458#issuecomment-51936356
-        other_name = getattr(other, "name", _default)
-        if other_name is _default or other_name == self.name:
-            return self.name
-        else:
-            return None
-
     def __array_wrap__(self, obj, context=None) -> Self:
         new_var = self.variable.__array_wrap__(obj, context)
         return self._replace(new_var)
@@ -4782,7 +4774,7 @@ def _binary_op(
             else f(other_variable_or_arraylike, self.variable)
         )
         coords, indexes = self.coords._merge_raw(other_coords, reflexive)
-        name = self._result_name(other)
+        name = result_name([self, other])
 
         return self._replace(variable, coords, name, indexes=indexes)
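The naming heuristic that `_binary_op` now gets from the shared `result_name` helper can be tried in isolation. The sketch below copies the function body shown as removed from `xarray/core/computation.py` in this diff, with two assumptions: a plain `object()` sentinel stands in for `utils.ReprObject("<default-name>")`, and `Named` is a hypothetical class used only to supply a `name` attribute.

```python
from collections.abc import Iterable
from typing import Any

# Assumption: a plain sentinel stands in for utils.ReprObject("<default-name>").
_DEFAULT_NAME = object()


def result_name(objects: Iterable[Any]) -> Any:
    """Return the one name shared by all named objects, otherwise None."""
    # Objects without a ``name`` attribute contribute the sentinel,
    # which is discarded before counting distinct names.
    names = {getattr(obj, "name", _DEFAULT_NAME) for obj in objects}
    names.discard(_DEFAULT_NAME)
    if len(names) == 1:
        (name,) = names
    else:
        name = None
    return name


class Named:
    """Hypothetical object carrying a ``name`` attribute, for illustration."""

    def __init__(self, name):
        self.name = name


same = result_name([Named("t"), Named("t")])  # both operands agree
mixed = result_name([Named("t"), Named("u")])  # conflicting names
with_scalar = result_name([Named("t"), 2.0])  # a float has no name
```

For two operands this gives the same answers as the removed `DataArray._result_name`: the name survives when the other operand is unnamed or agrees, and is dropped when the names conflict.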