Support of repr and deepcopy of recursive arrays #7112

Status: Merged (8 commits, Oct 6, 2022)
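
For context, a minimal sketch of the case this PR addresses (it mirrors the new tests below; before this change both calls ended in a ``RecursionError``):

```python
import copy
import xarray as xr

# A "recursive" array: its attrs refer back to the array itself (GH issue 7111).
da = xr.DataArray([1, 2], dims=["x"])
da.attrs["other"] = da

# With this PR, repr falls back to a "<recursive array>" placeholder and
# deepcopy completes by threading a memo dict through the copy machinery.
repr(da)
copy.deepcopy(da)
da.copy(deep=True)
```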
3 changes: 3 additions & 0 deletions doc/whats-new.rst
@@ -33,6 +33,9 @@ Deprecations

Bug fixes
~~~~~~~~~

- Support for recursively defined Arrays. Fixes repr and deepcopy. (:issue:`7111`, :pull:`7112`)
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Fixed :py:meth:`Dataset.transpose` to raise a more informative error. (:issue:`6502`, :pull:`7120`)
By `Patrick Naylor <https://github.com/patrick-naylor>`_

24 changes: 16 additions & 8 deletions xarray/core/dataarray.py
@@ -516,7 +516,7 @@ def _overwrite_indexes(
new_indexes.pop(name)

if rename_dims:
new_variable.dims = [rename_dims.get(d, d) for d in new_variable.dims]
new_variable.dims = tuple(rename_dims.get(d, d) for d in new_variable.dims)

return self._replace(
variable=new_variable, coords=new_coords, indexes=new_indexes
@@ -1169,25 +1169,33 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray:
--------
pandas.DataFrame.copy
"""
variable = self.variable.copy(deep=deep, data=data)
return self._copy(deep=deep, data=data)

def _copy(
self: T_DataArray,
deep: bool = True,
data: Any = None,
memo: dict[int, Any] | None = None,
) -> T_DataArray:
variable = self.variable._copy(deep=deep, data=data, memo=memo)
indexes, index_vars = self.xindexes.copy_indexes(deep=deep)

coords = {}
for k, v in self._coords.items():
if k in index_vars:
coords[k] = index_vars[k]
else:
coords[k] = v.copy(deep=deep)
coords[k] = v._copy(deep=deep, memo=memo)

return self._replace(variable, coords, indexes=indexes)

def __copy__(self: T_DataArray) -> T_DataArray:
return self.copy(deep=False)
return self._copy(deep=False)

def __deepcopy__(self: T_DataArray, memo=None) -> T_DataArray:
# memo does nothing but is required for compatibility with
# copy.deepcopy
return self.copy(deep=True)
def __deepcopy__(
self: T_DataArray, memo: dict[int, Any] | None = None
) -> T_DataArray:
return self._copy(deep=True, memo=memo)

# mutable objects should not be Hashable
# https://github.com/python/mypy/issues/4266
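
The ``copy``/``_copy`` split above keeps the public ``copy()`` signature unchanged while letting ``__deepcopy__`` forward the ``memo`` dict it receives from ``copy.deepcopy``. A standalone sketch of that pattern (illustrative names, not xarray API):

```python
from __future__ import annotations

import copy
from typing import Any


class Obj:
    """Toy class following the copy/_copy/__deepcopy__ split used above."""

    def copy(self, deep: bool = True) -> Obj:
        # Public method: no memo parameter exposed to users.
        return self._copy(deep=deep)

    def _copy(self, deep: bool = True, memo: dict[int, Any] | None = None) -> Obj:
        new = type(self)()
        new.__dict__ = (
            copy.deepcopy(self.__dict__, memo) if deep else copy.copy(self.__dict__)
        )
        return new

    def __copy__(self) -> Obj:
        return self._copy(deep=False)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Obj:
        # Forward memo instead of discarding it, so cycles can be resolved.
        return self._copy(deep=True, memo=memo)
```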
30 changes: 19 additions & 11 deletions xarray/core/dataset.py
@@ -1221,6 +1221,14 @@ def copy(
--------
pandas.DataFrame.copy
"""
return self._copy(deep=deep, data=data)

def _copy(
self: T_Dataset,
deep: bool = False,
data: Mapping[Any, ArrayLike] | None = None,
memo: dict[int, Any] | None = None,
) -> T_Dataset:
if data is None:
data = {}
elif not utils.is_dict_like(data):
@@ -1249,13 +1257,21 @@ def copy(
if k in index_vars:
variables[k] = index_vars[k]
else:
variables[k] = v.copy(deep=deep, data=data.get(k))
variables[k] = v._copy(deep=deep, data=data.get(k), memo=memo)

attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs)
encoding = copy.deepcopy(self._encoding) if deep else copy.copy(self._encoding)
attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
encoding = (
copy.deepcopy(self._encoding, memo) if deep else copy.copy(self._encoding)
)

return self._replace(variables, indexes=indexes, attrs=attrs, encoding=encoding)

def __copy__(self: T_Dataset) -> T_Dataset:
return self._copy(deep=False)

def __deepcopy__(self: T_Dataset, memo: dict[int, Any] | None = None) -> T_Dataset:
return self._copy(deep=True, memo=memo)

def as_numpy(self: T_Dataset) -> T_Dataset:
"""
Coerces wrapped data and coordinates into numpy arrays, returning a Dataset.
@@ -1332,14 +1348,6 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:

return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True)

def __copy__(self: T_Dataset) -> T_Dataset:
return self.copy(deep=False)

def __deepcopy__(self: T_Dataset, memo=None) -> T_Dataset:
# memo does nothing but is required for compatibility with
# copy.deepcopy
return self.copy(deep=True)

@property
def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]:
"""Places to look-up items for attribute-style access"""
6 changes: 4 additions & 2 deletions xarray/core/formatting.py
@@ -8,6 +8,7 @@
from collections import defaultdict
from datetime import datetime, timedelta
from itertools import chain, zip_longest
from reprlib import recursive_repr
from typing import Collection, Hashable

import numpy as np
@@ -385,7 +386,6 @@ def _mapping_repr(
expand_option_name="display_expand_data_vars",
)


attrs_repr = functools.partial(
_mapping_repr,
title="Attributes",
@@ -551,6 +551,7 @@ def short_data_repr(array):
return f"[{array.size} values with dtype={array.dtype}]"


@recursive_repr("<recursive array>")
def array_repr(arr):
from .variable import Variable

@@ -592,11 +593,12 @@ def array_repr(arr):
summary.append(unindexed_dims_str)

if arr.attrs:
summary.append(attrs_repr(arr.attrs))
summary.append(attrs_repr(arr.attrs, max_rows=max_rows))

return "\n".join(summary)


@recursive_repr("<recursive Dataset>")
def dataset_repr(ds):
summary = [f"<xarray.{type(ds).__name__}>"]

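
``reprlib.recursive_repr`` from the standard library is what makes the decorated ``array_repr`` and ``dataset_repr`` cycle-safe: if the decorated function is re-entered for the same object on the same thread, it returns the placeholder string instead of recursing. A minimal standalone illustration (not xarray code):

```python
from reprlib import recursive_repr


class Node:
    def __init__(self) -> None:
        self.other = None

    @recursive_repr("<recursive Node>")
    def __repr__(self) -> str:
        return f"Node(other={self.other!r})"


n = Node()
n.other = n        # direct cycle
print(repr(n))     # prints: Node(other=<recursive Node>)
```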
11 changes: 4 additions & 7 deletions xarray/core/utils.py
@@ -161,16 +161,13 @@ def equivalent(first: T, second: T) -> bool:
# TODO: refactor to avoid circular import
from . import duck_array_ops

if first is second:
return True
if isinstance(first, np.ndarray) or isinstance(second, np.ndarray):
return duck_array_ops.array_equiv(first, second)
elif isinstance(first, list) or isinstance(second, list):
if isinstance(first, list) or isinstance(second, list):
return list_equiv(first, second)
else:
return (
(first is second)
or (first == second)
or (pd.isnull(first) and pd.isnull(second))
)
return (first == second) or (pd.isnull(first) and pd.isnull(second))


def list_equiv(first, second):
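
The reordering above checks ``first is second`` before any ``==`` comparison. One motivation (a hedged reading of the change, not stated in the diff): equality on numpy arrays is elementwise and therefore ambiguous as a truth value, whereas identity settles the self-comparison case that recursive attrs produce. A small illustration of that pitfall:

```python
import numpy as np

a = np.array([1, 2, 3])

# Identity is cheap and unambiguous ...
assert a is a

# ... while elementwise equality cannot be used as a plain truth value:
try:
    bool(a == a)
except ValueError as err:
    print(err)  # "The truth value of an array with more than one element is ambiguous ..."
```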
32 changes: 22 additions & 10 deletions xarray/core/variable.py
@@ -918,7 +918,9 @@ def encoding(self, value):
except ValueError:
raise ValueError("encoding must be castable to a dictionary")

def copy(self, deep: bool = True, data: ArrayLike | None = None):
def copy(
self: T_Variable, deep: bool = True, data: ArrayLike | None = None
) -> T_Variable:
"""Returns a copy of this object.

If `deep=True`, the data array is loaded into memory and copied onto
Expand Down Expand Up @@ -974,6 +976,14 @@ def copy(self, deep: bool = True, data: ArrayLike | None = None):
--------
pandas.DataFrame.copy
"""
return self._copy(deep=deep, data=data)

def _copy(
self: T_Variable,
deep: bool = True,
data: ArrayLike | None = None,
memo: dict[int, Any] | None = None,
) -> T_Variable:
if data is None:
ndata = self._data

@@ -982,7 +992,7 @@ def copy(self, deep: bool = True, data: ArrayLike | None = None):
ndata = indexing.MemoryCachedArray(ndata.array)

if deep:
ndata = copy.deepcopy(ndata)
ndata = copy.deepcopy(ndata, memo)

else:
ndata = as_compatible_data(data)
@@ -993,8 +1003,10 @@ def copy(self, deep: bool = True, data: ArrayLike | None = None):
)
)

attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs)
encoding = copy.deepcopy(self._encoding) if deep else copy.copy(self._encoding)
attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
encoding = (
copy.deepcopy(self._encoding, memo) if deep else copy.copy(self._encoding)
)

# note: dims is already an immutable tuple
return self._replace(data=ndata, attrs=attrs, encoding=encoding)
@@ -1016,13 +1028,13 @@ def _replace(
encoding = copy.copy(self._encoding)
return type(self)(dims, data, attrs, encoding, fastpath=True)

def __copy__(self):
return self.copy(deep=False)
def __copy__(self: T_Variable) -> T_Variable:
return self._copy(deep=False)

def __deepcopy__(self, memo=None):
# memo does nothing but is required for compatibility with
# copy.deepcopy
return self.copy(deep=True)
def __deepcopy__(
self: T_Variable, memo: dict[int, Any] | None = None
) -> T_Variable:
return self._copy(deep=True, memo=memo)

# mutable objects should not be hashable
# https://github.com/python/mypy/issues/4266
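
Why forwarding ``memo`` into ``copy.deepcopy`` breaks the cycle: ``copy.deepcopy`` registers each container in ``memo`` before filling it in, so when the nested copy reaches the same attrs dict again it reuses the in-progress copy instead of recursing. A standalone sketch of that mechanism (simplified names, not xarray code):

```python
import copy


class Holder:
    """Toy stand-in for a Variable that carries an attrs dict."""

    def __init__(self, attrs=None):
        self.attrs = attrs if attrs is not None else {}

    def __deepcopy__(self, memo=None):
        # Forwarding memo lets deepcopy notice that self.attrs is already
        # being copied and reuse the in-progress copy, ending the recursion.
        return Holder(attrs=copy.deepcopy(self.attrs, memo))


h = Holder()
h.attrs["self"] = h          # attrs refer back to the holder
h2 = copy.deepcopy(h)        # completes without RecursionError
assert h2 is not h
assert h2.attrs["self"].attrs is h2.attrs   # the cycle is reproduced in the copy
```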
4 changes: 3 additions & 1 deletion xarray/tests/test_concat.py
@@ -219,7 +219,9 @@ def test_concat_errors(self):
concat([data, data], "new_dim", coords=["not_found"])

with pytest.raises(ValueError, match=r"global attributes not"):
data0, data1 = deepcopy(split_data)
# call deepcopy separately to get unique attrs
data0 = deepcopy(split_data[0])
data1 = deepcopy(split_data[1])
data1.attrs["foo"] = "bar"
concat([data0, data1], "dim1", compat="identical")
assert_identical(data, concat([data0, data1], "dim1", compat="equals"))
22 changes: 22 additions & 0 deletions xarray/tests/test_dataarray.py
@@ -6488,6 +6488,28 @@ def test_deepcopy_obj_array() -> None:
assert x0.values[0] is not x1.values[0]


def test_deepcopy_recursive() -> None:
# GH:issue:7111

# direct recursion
da = xr.DataArray([1, 2], dims=["x"])
da.attrs["other"] = da

# TODO: cannot use assert_identical on recursive Vars yet...
# let's just ensure that deep copy works without RecursionError
da.copy(deep=True)

# indirect recursion
da2 = xr.DataArray([5, 6], dims=["y"])
da.attrs["other"] = da2
da2.attrs["other"] = da

# TODO: cannot use assert_identical on recursive Vars yet...
# let's just ensure that deep copy works without RecursionError
da.copy(deep=True)
da2.copy(deep=True)


def test_clip(da: DataArray) -> None:
with raise_if_dask_computes():
result = da.clip(min=0.5)
22 changes: 22 additions & 0 deletions xarray/tests/test_dataset.py
@@ -6687,6 +6687,28 @@ def test_deepcopy_obj_array() -> None:
assert x0["foo"].values[0] is not x1["foo"].values[0]


def test_deepcopy_recursive() -> None:
# GH:issue:7111

# direct recursion
ds = xr.Dataset({"a": (["x"], [1, 2])})
ds.attrs["other"] = ds

# TODO: cannot use assert_identical on recursive Vars yet...
# let's just ensure that deep copy works without RecursionError
ds.copy(deep=True)

# indirect recursion
ds2 = xr.Dataset({"b": (["y"], [3, 4])})
ds.attrs["other"] = ds2
ds2.attrs["other"] = ds

# TODO: cannot use assert_identical on recursive Vars yet...
# let's just ensure that deep copy works without RecursionError
ds.copy(deep=True)
ds2.copy(deep=True)


def test_clip(ds) -> None:
result = ds.clip(min=0.5)
assert all((result.min(...) >= 0.5).values())
33 changes: 33 additions & 0 deletions xarray/tests/test_formatting.py
@@ -431,6 +431,24 @@ def test_array_repr_variable(self) -> None:
with xr.set_options(display_expand_data=False):
formatting.array_repr(var)

def test_array_repr_recursive(self) -> None:
# GH:issue:7111

# direct recursion
var = xr.Variable("x", [0, 1])
var.attrs["x"] = var
formatting.array_repr(var)

da = xr.DataArray([0, 1], dims=["x"])
da.attrs["x"] = da
formatting.array_repr(da)

# indirect recursion
var.attrs["x"] = da
da.attrs["x"] = var
formatting.array_repr(var)
formatting.array_repr(da)

@requires_dask
def test_array_scalar_format(self) -> None:
# Test numpy scalars:
@@ -615,6 +633,21 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None:
assert actual == expected


def test__mapping_repr_recursive() -> None:
# GH:issue:7111

# direct recursion
ds = xr.Dataset({"a": [["x"], [1, 2, 3]]})
ds.attrs["ds"] = ds
formatting.dataset_repr(ds)

# indirect recursion
ds2 = xr.Dataset({"b": [["y"], [1, 2, 3]]})
ds.attrs["ds"] = ds2
ds2.attrs["ds"] = ds
formatting.dataset_repr(ds2)


def test__element_formatter(n_elements: int = 100) -> None:
expected = """\
Dimensions without coordinates: dim_0: 3, dim_1: 3, dim_2: 3, dim_3: 3,