Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add argument check_dims to assert_allclose to allow transposed inputs (#5733) #8991

Merged
merged 17 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ New Features
for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray`
then, such as broadcasting.
By `Ilan Gold <https://github.com/ilan-gold>`_.
- :py:func:`testing.assert_allclose`/:py:func:`testing.assert_equal` now accept a new argument `check_dims="transpose"`, controlling whether a transposed array is considered equal. (:issue:`5733`, :pull:`8991`)
By `Ignacio Martinez Vazquez <https://github.com/ignamv>`_.
- Added the option to avoid automatically creating 1D pandas indexes in :py:meth:`Dataset.expand_dims()`, by passing the new kwarg
`create_index=False`. (:pull:`8960`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.
Expand Down
29 changes: 25 additions & 4 deletions xarray/testing/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,18 @@ def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False):
raise TypeError(f"{type(a)} not of type DataTree")


def maybe_transpose_dims(a, b, check_dim_order: bool):
"""Helper for assert_equal/allclose/identical"""
__tracebackhide__ = True
if not isinstance(a, (Variable, DataArray, Dataset)):
ignamv marked this conversation as resolved.
Show resolved Hide resolved
return b
if not check_dim_order and set(a.dims) == set(b.dims):
# Ensure transpose won't fail if a dimension is missing
# If this is the case, the difference will be caught by the caller
return b.transpose(*a.dims)
return b


@overload
def assert_equal(a, b): ...

Expand All @@ -104,7 +116,7 @@ def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ...


@ensure_warnings
def assert_equal(a, b, from_root=True):
def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
"""Like :py:func:`numpy.testing.assert_array_equal`, but for xarray
objects.

Expand All @@ -127,6 +139,8 @@ def assert_equal(a, b, from_root=True):
Only used when comparing DataTree objects. Indicates whether or not to
first traverse to the root of the trees before checking for isomorphism.
If a & b have no parents then this has no effect.
check_dim_order : bool, optional, default is True
Whether dimensions must be in the same order.

See Also
--------
Expand All @@ -137,6 +151,7 @@ def assert_equal(a, b, from_root=True):
assert (
type(a) == type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates)
)
b = maybe_transpose_dims(a, b, check_dim_order)
if isinstance(a, (Variable, DataArray)):
assert a.equals(b), formatting.diff_array_repr(a, b, "equals")
elif isinstance(a, Dataset):
Expand Down Expand Up @@ -182,6 +197,8 @@ def assert_identical(a, b, from_root=True):
Only used when comparing DataTree objects. Indicates whether or not to
first traverse to the root of the trees before checking for isomorphism.
If a & b have no parents then this has no effect.
check_dim_order : bool, optional, default is True
Whether dimensions must be in the same order.

See Also
--------
Expand Down Expand Up @@ -213,7 +230,9 @@ def assert_identical(a, b, from_root=True):


@ensure_warnings
def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
def assert_allclose(
a, b, rtol=1e-05, atol=1e-08, decode_bytes=True, check_dim_order: bool = True
):
"""Like :py:func:`numpy.testing.assert_allclose`, but for xarray objects.

Raises an AssertionError if two objects are not equal up to desired
Expand All @@ -233,23 +252,25 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
Whether byte dtypes should be decoded to strings as UTF-8 or not.
This is useful for testing serialization methods on Python 3 that
return saved strings as bytes.
check_dim_order : bool, optional, default is True
Whether dimensions must be in the same order.

See Also
--------
assert_identical, assert_equal, numpy.testing.assert_allclose
"""
__tracebackhide__ = True
assert type(a) == type(b)
b = maybe_transpose_dims(a, b, check_dim_order)

equiv = functools.partial(
_data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes
)
equiv.__name__ = "allclose"
equiv.__name__ = "allclose" # type: ignore[attr-defined]

def compat_variable(a, b):
a = getattr(a, "variable", a)
b = getattr(b, "variable", b)

return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data))

if isinstance(a, Variable):
Expand Down
19 changes: 19 additions & 0 deletions xarray/tests/test_assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,25 @@ def test_allclose_regression() -> None:
def test_assert_allclose(obj1, obj2) -> None:
with pytest.raises(AssertionError):
xr.testing.assert_allclose(obj1, obj2)
with pytest.raises(AssertionError):
xr.testing.assert_allclose(obj1, obj2, check_dim_order=False)


@pytest.mark.parametrize("func", ["assert_equal", "assert_allclose"])
def test_assert_allclose_equal_transpose(func) -> None:
"""Transposed DataArray raises assertion unless check_dim_order=False."""
obj1 = xr.DataArray([[0, 1, 2], [2, 3, 4]], dims=["a", "b"])
obj2 = xr.DataArray([[0, 2], [1, 3], [2, 4]], dims=["b", "a"])
with pytest.raises(AssertionError):
getattr(xr.testing, func)(obj1, obj2)
getattr(xr.testing, func)(obj1, obj2, check_dim_order=False)
ds1 = obj1.to_dataset(name="varname")
ds1["var2"] = obj1
ds2 = obj1.to_dataset(name="varname")
ds2["var2"] = obj1.transpose()
with pytest.raises(AssertionError):
getattr(xr.testing, func)(ds1, ds2)
getattr(xr.testing, func)(ds1, ds2, check_dim_order=False)


@pytest.mark.filterwarnings("error")
Expand Down
Loading