From a6bacfe1fe3e4ce1e13367130d41009be441ae67 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 10 Sep 2024 16:36:07 -0700 Subject: [PATCH] Update DataTree repr to indicate inheritance (#9470) * Update DataTree repr to indicate inheritance Fixes https://github.com/pydata/xarray/issues/9463 * fix whitespace * add more repr tests, fix failure * fix failure on windows * fix repr for inherited dimensions --- asv_bench/benchmarks/datatree.py | 2 +- xarray/core/formatting.py | 140 ++++++++++++++++++++++++------- xarray/tests/test_datatree.py | 125 ++++++++++++++++++++++++--- xarray/tests/test_formatting.py | 5 +- 4 files changed, 223 insertions(+), 49 deletions(-) diff --git a/asv_bench/benchmarks/datatree.py b/asv_bench/benchmarks/datatree.py index 13eedd0a518..9f1774f60ac 100644 --- a/asv_bench/benchmarks/datatree.py +++ b/asv_bench/benchmarks/datatree.py @@ -6,7 +6,7 @@ class Datatree: def setup(self): run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})}) self.d_few = {"run1": run1} - self.d_many = {f"run{i}": run1.copy() for i in range(100)} + self.d_many = {f"run{i}": xr.Dataset({"a": 1}) for i in range(100)} def time_from_dict_few(self): DataTree.from_dict(self.d_few) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 110f80f8f5f..d17a72aac27 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -6,13 +6,13 @@ import contextlib import functools import math -from collections import defaultdict -from collections.abc import Collection, Hashable, Sequence +from collections import ChainMap, defaultdict +from collections.abc import Collection, Hashable, Mapping, Sequence from datetime import datetime, timedelta from itertools import chain, zip_longest from reprlib import recursive_repr from textwrap import dedent -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -29,6 +29,7 @@ if TYPE_CHECKING: from xarray.core.coordinates import AbstractCoordinates from xarray.core.datatree import DataTree + from xarray.core.variable import Variable UNITS = ("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") @@ -318,7 +319,7 @@ def inline_variable_array_repr(var, max_width): def summarize_variable( name: Hashable, - var, + var: Variable, col_width: int, max_width: int | None = None, is_index: bool = False, @@ -446,6 +447,21 @@ def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None): ) +def inherited_coords_repr(node: DataTree, col_width=None, max_rows=None): + coords = _inherited_vars(node._coord_variables) + if col_width is None: + col_width = _calculate_col_width(coords) + return _mapping_repr( + coords, + title="Inherited coordinates", + summarizer=summarize_variable, + expand_option_name="display_expand_coords", + col_width=col_width, + indexes=node._indexes, + max_rows=max_rows, + ) + + def inline_index_repr(index: pd.Index, max_width=None): if hasattr(index, "_repr_inline_"): repr_ = index._repr_inline_(max_width=max_width) @@ -498,12 +514,12 @@ def filter_nondefault_indexes(indexes, filter_indexes: bool): } -def indexes_repr(indexes, max_rows: int | None = None) -> str: +def indexes_repr(indexes, max_rows: int | None = None, title: str = "Indexes") -> str: col_width = _calculate_col_width(chain.from_iterable(indexes)) return _mapping_repr( indexes, - "Indexes", + title, summarize_index, "display_expand_indexes", col_width=col_width, @@ -571,8 +587,10 @@ def _element_formatter( return "".join(out) -def dim_summary_limited(obj, col_width: int, max_rows: int | None = None) -> str: - elements = [f"{k}: {v}" for k, v in obj.sizes.items()] +def dim_summary_limited( + sizes: Mapping[Any, int], col_width: int, max_rows: int | None = None +) -> str: + elements = [f"{k}: {v}" for k, v in sizes.items()] return _element_formatter(elements, col_width, max_rows) @@ -676,7 +694,7 @@ def array_repr(arr): data_repr = inline_variable_array_repr(arr.variable, OPTIONS["display_width"]) start = f" Size: {nbytes_str}", @@ -721,7 +739,9 @@ def dataset_repr(ds): max_rows = OPTIONS["display_max_rows"] dims_start = pretty_print("Dimensions:", col_width) - dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows) + dims_values = dim_summary_limited( + ds.sizes, col_width=col_width + 1, max_rows=max_rows + ) summary.append(f"{dims_start}({dims_values})") if ds.coords: @@ -756,7 +776,9 @@ def dims_and_coords_repr(ds) -> str: max_rows = OPTIONS["display_max_rows"] dims_start = pretty_print("Dimensions:", col_width) - dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows) + dims_values = dim_summary_limited( + ds.sizes, col_width=col_width + 1, max_rows=max_rows + ) summary.append(f"{dims_start}({dims_values})") if ds.coords: @@ -1050,19 +1072,73 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat): return "\n".join(summary) -def _single_node_repr(node: DataTree) -> str: - """Information about this node, not including its relationships to other nodes.""" - if node.has_data or node.has_attrs: - # TODO: change this to inherited=False, in order to clarify what is - # inherited? https://github.com/pydata/xarray/issues/9463 - node_view = node._to_dataset_view(rebuild_dims=False, inherited=True) - ds_info = "\n" + repr(node_view) - else: - ds_info = "" - return f"Group: {node.path}{ds_info}" +def _inherited_vars(mapping: ChainMap) -> dict: + return {k: v for k, v in mapping.parents.items() if k not in mapping.maps[0]} + + +def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str: + summary = [f"Group: {node.path}"] + + col_width = _calculate_col_width(node.variables) + max_rows = OPTIONS["display_max_rows"] + + inherited_coords = _inherited_vars(node._coord_variables) + + # Only show dimensions if also showing a variable or coordinates section. + show_dims = ( + node._node_coord_variables + or (show_inherited and inherited_coords) + or node._data_variables + ) + + dim_sizes = node.sizes if show_inherited else node._node_dims + + if show_dims: + # Includes inherited dimensions. + dims_start = pretty_print("Dimensions:", col_width) + dims_values = dim_summary_limited( + dim_sizes, col_width=col_width + 1, max_rows=max_rows + ) + summary.append(f"{dims_start}({dims_values})") + + if node._node_coord_variables: + summary.append(coords_repr(node.coords, col_width=col_width, max_rows=max_rows)) + + if show_inherited and inherited_coords: + summary.append( + inherited_coords_repr(node, col_width=col_width, max_rows=max_rows) + ) + + if show_dims: + unindexed_dims_str = unindexed_dims_repr( + dim_sizes, node.coords, max_rows=max_rows + ) + if unindexed_dims_str: + summary.append(unindexed_dims_str) + + if node._data_variables: + summary.append( + data_vars_repr(node._data_variables, col_width=col_width, max_rows=max_rows) + ) + # TODO: only show indexes defined at this node, with a separate section for + # inherited indexes (if show_inherited=True) + display_default_indexes = _get_boolean_with_default( + "display_default_indexes", False + ) + xindexes = filter_nondefault_indexes( + _get_indexes_dict(node.xindexes), not display_default_indexes + ) + if xindexes: + summary.append(indexes_repr(xindexes, max_rows=max_rows)) -def datatree_repr(dt: DataTree): + if node.attrs: + summary.append(attrs_repr(node.attrs, max_rows=max_rows)) + + return "\n".join(summary) + + +def datatree_repr(dt: DataTree) -> str: """A printable representation of the structure of this entire tree.""" renderer = RenderDataTree(dt) @@ -1070,19 +1146,21 @@ def datatree_repr(dt: DataTree): header = f"" lines = [header] + show_inherited = True for pre, fill, node in renderer: - node_repr = _single_node_repr(node) + node_repr = _datatree_node_repr(node, show_inherited=show_inherited) + show_inherited = False # only show inherited coords on the root + + raw_repr_lines = node_repr.splitlines() - node_line = f"{pre}{node_repr.splitlines()[0]}" + node_line = f"{pre}{raw_repr_lines[0]}" lines.append(node_line) - if node.has_data or node.has_attrs: - ds_repr = node_repr.splitlines()[2:] - for line in ds_repr: - if len(node.children) > 0: - lines.append(f"{fill}{renderer.style.vertical}{line}") - else: - lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}") + for line in raw_repr_lines[1:]: + if len(node.children) > 0: + lines.append(f"{fill}{renderer.style.vertical}{line}") + else: + lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}") return "\n".join(lines) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 6d208a5cf98..353f1fab708 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -777,7 +777,8 @@ def test_operation_with_attrs_but_no_data(self): class TestRepr: - def test_repr(self): + + def test_repr_four_nodes(self): dt = DataTree.from_dict( { "/": xr.Dataset( @@ -801,18 +802,13 @@ def test_repr(self): │ Data variables: │ e (x) float64 16B 1.0 2.0 └── Group: /b - │ Dimensions: (x: 2, y: 1) - │ Coordinates: - │ * x (x) float64 16B 2.0 3.0 + │ Dimensions: (y: 1) │ Dimensions without coordinates: y │ Data variables: │ f (y) float64 8B 3.0 ├── Group: /b/c └── Group: /b/d - Dimensions: (x: 2, y: 1) - Coordinates: - * x (x) float64 16B 2.0 3.0 - Dimensions without coordinates: y + Dimensions: () Data variables: g float64 8B 4.0 """ @@ -825,23 +821,126 @@ def test_repr(self): Group: /b │ Dimensions: (x: 2, y: 1) - │ Coordinates: + │ Inherited coordinates: │ * x (x) float64 16B 2.0 3.0 │ Dimensions without coordinates: y │ Data variables: │ f (y) float64 8B 3.0 ├── Group: /b/c └── Group: /b/d - Dimensions: (x: 2, y: 1) - Coordinates: - * x (x) float64 16B 2.0 3.0 - Dimensions without coordinates: y + Dimensions: () Data variables: g float64 8B 4.0 """ ).strip() assert result == expected + result = repr(dt.b.d) + expected = dedent( + """ + + Group: /b/d + Dimensions: (x: 2, y: 1) + Inherited coordinates: + * x (x) float64 16B 2.0 3.0 + Dimensions without coordinates: y + Data variables: + g float64 8B 4.0 + """ + ).strip() + assert result == expected + + def test_repr_two_children(self): + tree = DataTree.from_dict( + { + "/": Dataset(coords={"x": [1.0]}), + "/first_child": None, + "/second_child": Dataset({"foo": ("x", [0.0])}), + } + ) + + result = repr(tree) + expected = dedent( + """ + + Group: / + │ Dimensions: (x: 1) + │ Coordinates: + │ * x (x) float64 8B 1.0 + ├── Group: /first_child + └── Group: /second_child + Dimensions: (x: 1) + Data variables: + foo (x) float64 8B 0.0 + """ + ).strip() + assert result == expected + + result = repr(tree["first_child"]) + expected = dedent( + """ + + Group: /first_child + Dimensions: (x: 1) + Inherited coordinates: + * x (x) float64 8B 1.0 + """ + ).strip() + assert result == expected + + result = repr(tree["second_child"]) + expected = dedent( + """ + + Group: /second_child + Dimensions: (x: 1) + Inherited coordinates: + * x (x) float64 8B 1.0 + Data variables: + foo (x) float64 8B 0.0 + """ + ).strip() + assert result == expected + + def test_repr_inherited_dims(self): + tree = DataTree.from_dict( + { + "/": Dataset({"foo": ("x", [1.0])}), + "/child": Dataset({"bar": ("y", [2.0])}), + } + ) + + result = repr(tree) + expected = dedent( + """ + + Group: / + │ Dimensions: (x: 1) + │ Dimensions without coordinates: x + │ Data variables: + │ foo (x) float64 8B 1.0 + └── Group: /child + Dimensions: (y: 1) + Dimensions without coordinates: y + Data variables: + bar (y) float64 8B 2.0 + """ + ).strip() + assert result == expected + + result = repr(tree["child"]) + expected = dedent( + """ + + Group: /child + Dimensions: (x: 1, y: 1) + Dimensions without coordinates: x, y + Data variables: + bar (y) float64 8B 2.0 + """ + ).strip() + assert result == expected + def _exact_match(message: str) -> str: return re.escape(dedent(message).strip()) diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index a2fef9d9b6b..6066eeacfee 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -632,9 +632,6 @@ def test_datatree_print_empty_node_with_attrs(self): """\ Group: / - Dimensions: () - Data variables: - *empty* Attributes: note: has attrs""" ) @@ -888,7 +885,7 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None: col_width = formatting._calculate_col_width(ds.variables) dims_start = formatting.pretty_print("Dimensions:", col_width) dims_values = formatting.dim_summary_limited( - ds, col_width=col_width + 1, max_rows=display_max_rows + ds.sizes, col_width=col_width + 1, max_rows=display_max_rows ) expected_size = "1kB" expected = f"""\