Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to pass callable assertion failure message generator #5607

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions xarray/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _data_allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08, decode_bytes=Tru


@ensure_warnings
def assert_equal(a, b):
def assert_equal(a, b, fail_func=None):
"""Like :py:func:`numpy.testing.assert_array_equal`, but for xarray
objects.

Expand All @@ -69,24 +69,32 @@ def assert_equal(a, b):
The first object to compare.
b : xarray.Dataset, xarray.DataArray or xarray.Variable
The second object to compare.
fail_func : callable, optional
Function that takes a, b as arguments, and returns a string to use as
an assertion failure message.

See Also
--------
assert_identical, assert_allclose, Dataset.equals, DataArray.equals
numpy.testing.assert_array_equal
"""
__tracebackhide__ = True

if fail_func is None:
def fail_func(a, b):
return formatting.diff_array_repr(a, b, "equals")

assert type(a) == type(b)
if isinstance(a, (Variable, DataArray)):
assert a.equals(b), formatting.diff_array_repr(a, b, "equals")
assert a.equals(b), fail_func(a, b)
elif isinstance(a, Dataset):
assert a.equals(b), formatting.diff_dataset_repr(a, b, "equals")
assert a.equals(b), fail_func(a, b)
else:
raise TypeError("{} not supported by assertion comparison".format(type(a)))


@ensure_warnings
def assert_identical(a, b):
def assert_identical(a, b, fail_func=None):
"""Like :py:func:`xarray.testing.assert_equal`, but also matches the
objects' names and attributes.

Expand All @@ -98,26 +106,34 @@ def assert_identical(a, b):
The first object to compare.
b : xarray.Dataset, xarray.DataArray or xarray.Variable
The second object to compare.
fail_func : callable, optional
Function that takes a, b as arguments, and returns a string to use as
an assertion failure message.

See Also
--------
assert_equal, assert_allclose, Dataset.equals, DataArray.equals
"""
__tracebackhide__ = True

if fail_func is None:
def fail_func(a, b):
return formatting.diff_array_repr(a, b, "indentical")

assert type(a) == type(b)
if isinstance(a, Variable):
assert a.identical(b), formatting.diff_array_repr(a, b, "identical")
assert a.identical(b), fail_func(a, b)
elif isinstance(a, DataArray):
assert a.name == b.name
assert a.identical(b), formatting.diff_array_repr(a, b, "identical")
assert a.identical(b), fail_func(a, b)
elif isinstance(a, (Dataset, Variable)):
assert a.identical(b), formatting.diff_dataset_repr(a, b, "identical")
assert a.identical(b), fail_func(a, b)
else:
raise TypeError("{} not supported by assertion comparison".format(type(a)))


@ensure_warnings
def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True, fail_func=None):
"""Like :py:func:`numpy.testing.assert_allclose`, but for xarray objects.

Raises an AssertionError if two objects are not equal up to desired
Expand All @@ -137,6 +153,9 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
Whether byte dtypes should be decoded to strings as UTF-8 or not.
This is useful for testing serialization methods on Python 3 that
return saved strings as bytes.
fail_func : callable, optional
Function that takes a, b as arguments, and returns a string to use as
an assertion failure message.

See Also
--------
Expand All @@ -156,19 +175,23 @@ def compat_variable(a, b):

return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data))

if fail_func is None:
def fail_func(a, b):
return formatting.diff_array_repr(a, b, compat=equiv)

if isinstance(a, Variable):
allclose = compat_variable(a, b)
assert allclose, formatting.diff_array_repr(a, b, compat=equiv)
assert allclose, fail_func(a, b)
elif isinstance(a, DataArray):
allclose = utils.dict_equiv(
a.coords, b.coords, compat=compat_variable
) and compat_variable(a.variable, b.variable)
assert allclose, formatting.diff_array_repr(a, b, compat=equiv)
assert allclose, fail_func(a, b)
elif isinstance(a, Dataset):
allclose = a._coord_names == b._coord_names and utils.dict_equiv(
a.variables, b.variables, compat=compat_variable
)
assert allclose, formatting.diff_dataset_repr(a, b, compat=equiv)
assert allclose, fail_func(a, b)
else:
raise TypeError("{} not supported by assertion comparison".format(type(a)))

Expand Down