
DataTree.assign_coords and similar should only set node coords #9472

Open
@TomNicholas


What is your issue?

Currently DataTree.assign_coords tries to call Dataset.assign_coords on every node in the subtree. This was always a little weird, but now that we have implemented coordinate inheritance (#9077) it's really silly, because we only need to assign to the root node and the coordinates will be accessible on every child.
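To make the inheritance argument concrete, here is a toy model in plain Python. The Node class and helper functions are hypothetical and are not xarray's DataTree API. For simplicity they mutate in place, whereas the real assign_coords returns a new object. Each node stores only its own coordinates, and a lookup merges in whatever its ancestors define, which is roughly what coordinate inheritance provides. Writing a coordinate into every node of the subtree and writing it once at the top node expose the same coordinates on the child, but the mapped version keeps a redundant copy on every node.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterator, Optional


# Hypothetical toy model, NOT xarray's DataTree: each node stores only its own
# coordinates, and a lookup merges in whatever its ancestors define.
@dataclass(eq=False)
class Node:
    name: str
    parent: Optional[Node] = field(default=None, repr=False)
    local_coords: dict = field(default_factory=dict)
    children: dict = field(default_factory=dict, repr=False)

    def add_child(self, name: str) -> Node:
        child = Node(name, parent=self)
        self.children[name] = child
        return child

    def visible_coords(self) -> dict:
        # Walk from this node up to the root ...
        lineage = []
        node = self
        while node is not None:
            lineage.append(node)
            node = node.parent
        # ... then merge root-first so coordinates defined nearer this node win.
        merged = {}
        for ancestor in reversed(lineage):
            merged.update(ancestor.local_coords)
        return merged


def subtree(node: Node) -> Iterator[Node]:
    """Yield a node and all of its descendants (depth-first)."""
    yield node
    for child in node.children.values():
        yield from subtree(child)


def assign_coords_mapped(node: Node, new: dict) -> None:
    """Models the current behaviour: write the same coords into every node of the subtree."""
    for n in subtree(node):
        n.local_coords.update(new)


def assign_coords_node_only(node: Node, new: dict) -> None:
    """Models the proposed behaviour: write once; descendants see it via inheritance."""
    node.local_coords.update(new)


def stored_copies(root: Node, coord_name: str) -> int:
    """Count how many nodes physically store a coordinate with this name."""
    return sum(coord_name in n.local_coords for n in subtree(root))


if __name__ == "__main__":
    mapped = Node("root")
    mapped.add_child("child")
    assign_coords_mapped(mapped, {"c": 11})

    node_only = Node("root")
    node_only.add_child("child")
    assign_coords_node_only(node_only, {"c": 11})

    # The child sees c == 11 either way ...
    print(mapped.children["child"].visible_coords())     # {'c': 11}
    print(node_only.children["child"].visible_coords())  # {'c': 11}
    # ... but only the node-only version stores it once.
    print(stored_copies(mapped, "c"), stored_copies(node_only, "c"))  # 2 1
```

The real DataTree also has to check that inherited coordinates align with each child's dimensions; this sketch ignores alignment and dimension checks entirely.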

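There also seems to be a bug right now (the session assumes import numpy as np, from xarray import Dataset, and from xarray.core.datatree import DataTree):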

In [6]: ds = Dataset(
   ...:     data_vars={
   ...:         "foo": (["x", "y"], np.random.randn(2, 3)),
   ...:     },
   ...:     coords={
   ...:         "x": ("x", np.array([-1, -2], "int64")),
   ...:         "y": ("y", np.array([0, 1, 2], "int64")),
   ...:         "a": ("x", np.array([4, 5], "int64")),
   ...:         "b": np.int64(-10),
   ...:     },
   ...: )
   ...: dt = DataTree(data=ds)
   ...: dt["child"] = DataTree()

In [7]: child = dt["child"]

In [8]: child.assign_coords({'c': 11})
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[8], line 1
----> 1 child.assign_coords({'c': 11})

File ~/Documents/Work/Code/xarray/xarray/core/datatree_mapping.py:193, in map_over_subtree.<locals>._map_over_subtree(*args, **kwargs)
    190     out_data_objects[node_of_first_tree.path] = results
    192 # Find out how many return values we received
--> 193 num_return_values = _check_all_return_values(out_data_objects)
    195 # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
    196 original_root_path = first_tree.path

File ~/Documents/Work/Code/xarray/xarray/core/datatree_mapping.py:282, in _check_all_return_values(returned_objects)
    279 """Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""
    281 if all(r is None for r in returned_objects.values()):
--> 282     raise TypeError(
    283         "Called supplied function on all nodes but found a return value of None for"
    284         "all of them."
    285     )
    287 result_data_objects = [
    288     (path_to_node, r)
    289     for path_to_node, r in returned_objects.items()
    290     if r is not None
    291 ]
    293 if len(result_data_objects) == 1:
    294     # Only one node in the tree: no need to check consistency of results between nodes

TypeError: Called supplied function on all nodes but found a return value of None forall of them.

We should change assign_coords (and possibly any similar functions, like set_coords) to operate on just one node, and not try to map over the subtree.

This would also help make the coordinate-related behaviour of DataTree more similar to that of Dataset, which relates to #9203 (comment).

cc @shoyer
