Skip to content

Commit

Permalink
Backport PR pandas-dev#56724: TST: Don't ignore tolerance for integer…
Browse files Browse the repository at this point in the history
… series
  • Loading branch information
phofl authored and meeseeksmachine committed Jan 8, 2024
1 parent 41f22b3 commit 4b6067b
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 26 deletions.
98 changes: 72 additions & 26 deletions pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np

from pandas._libs import lib
from pandas._libs.missing import is_matching_na
from pandas._libs.sparse import SparseIndex
import pandas._libs.testing as _testing
Expand Down Expand Up @@ -698,9 +699,9 @@ def assert_extension_array_equal(
right,
check_dtype: bool | Literal["equiv"] = True,
index_values=None,
check_exact: bool = False,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
check_exact: bool | lib.NoDefault = lib.no_default,
rtol: float | lib.NoDefault = lib.no_default,
atol: float | lib.NoDefault = lib.no_default,
obj: str = "ExtensionArray",
) -> None:
"""
Expand All @@ -715,7 +716,12 @@ def assert_extension_array_equal(
index_values : Index | numpy.ndarray, default None
Optional index (shared by both left and right), used in output.
check_exact : bool, default False
Whether to compare number exactly. Only takes effect for float dtypes.
Whether to compare number exactly.
.. versionchanged:: 2.2.0
Defaults to True for integer dtypes if none of
``check_exact``, ``rtol`` and ``atol`` are specified.
rtol : float, default 1e-5
Relative tolerance. Only used when check_exact is False.
atol : float, default 1e-8
Expand All @@ -739,6 +745,23 @@ def assert_extension_array_equal(
>>> b, c = a.array, a.array
>>> tm.assert_extension_array_equal(b, c)
"""
if (
check_exact is lib.no_default
and rtol is lib.no_default
and atol is lib.no_default
):
check_exact = (
is_numeric_dtype(left.dtype)
and not is_float_dtype(left.dtype)
or is_numeric_dtype(right.dtype)
and not is_float_dtype(right.dtype)
)
elif check_exact is lib.no_default:
check_exact = False

rtol = rtol if rtol is not lib.no_default else 1.0e-5
atol = atol if atol is not lib.no_default else 1.0e-8

assert isinstance(left, ExtensionArray), "left is not an ExtensionArray"
assert isinstance(right, ExtensionArray), "right is not an ExtensionArray"
if check_dtype:
Expand Down Expand Up @@ -784,10 +807,7 @@ def assert_extension_array_equal(

left_valid = left[~left_na].to_numpy(dtype=object)
right_valid = right[~right_na].to_numpy(dtype=object)
if check_exact or (
(is_numeric_dtype(left.dtype) and not is_float_dtype(left.dtype))
or (is_numeric_dtype(right.dtype) and not is_float_dtype(right.dtype))
):
if check_exact:
assert_numpy_array_equal(
left_valid, right_valid, obj=obj, index_values=index_values
)
Expand All @@ -811,14 +831,14 @@ def assert_series_equal(
check_index_type: bool | Literal["equiv"] = "equiv",
check_series_type: bool = True,
check_names: bool = True,
check_exact: bool = False,
check_exact: bool | lib.NoDefault = lib.no_default,
check_datetimelike_compat: bool = False,
check_categorical: bool = True,
check_category_order: bool = True,
check_freq: bool = True,
check_flags: bool = True,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
rtol: float | lib.NoDefault = lib.no_default,
atol: float | lib.NoDefault = lib.no_default,
obj: str = "Series",
*,
check_index: bool = True,
Expand All @@ -841,7 +861,12 @@ def assert_series_equal(
check_names : bool, default True
Whether to check the Series and Index names attribute.
check_exact : bool, default False
Whether to compare number exactly. Only takes effect for float dtypes.
Whether to compare number exactly.
.. versionchanged:: 2.2.0
Defaults to True for integer dtypes if none of
``check_exact``, ``rtol`` and ``atol`` are specified.
check_datetimelike_compat : bool, default False
Compare datetime-like which is comparable ignoring dtype.
check_categorical : bool, default True
Expand Down Expand Up @@ -877,6 +902,22 @@ def assert_series_equal(
>>> tm.assert_series_equal(a, b)
"""
__tracebackhide__ = True
if (
check_exact is lib.no_default
and rtol is lib.no_default
and atol is lib.no_default
):
check_exact = (
is_numeric_dtype(left.dtype)
and not is_float_dtype(left.dtype)
or is_numeric_dtype(right.dtype)
and not is_float_dtype(right.dtype)
)
elif check_exact is lib.no_default:
check_exact = False

rtol = rtol if rtol is not lib.no_default else 1.0e-5
atol = atol if atol is not lib.no_default else 1.0e-8

if not check_index and check_like:
raise ValueError("check_like must be False if check_index is False")
Expand Down Expand Up @@ -931,10 +972,7 @@ def assert_series_equal(
pass
else:
assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}")
if check_exact or (
(is_numeric_dtype(left.dtype) and not is_float_dtype(left.dtype))
or (is_numeric_dtype(right.dtype) and not is_float_dtype(right.dtype))
):
if check_exact:
left_values = left._values
right_values = right._values
# Only check exact if dtype is numeric
Expand Down Expand Up @@ -1061,14 +1099,14 @@ def assert_frame_equal(
check_frame_type: bool = True,
check_names: bool = True,
by_blocks: bool = False,
check_exact: bool = False,
check_exact: bool | lib.NoDefault = lib.no_default,
check_datetimelike_compat: bool = False,
check_categorical: bool = True,
check_like: bool = False,
check_freq: bool = True,
check_flags: bool = True,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
rtol: float | lib.NoDefault = lib.no_default,
atol: float | lib.NoDefault = lib.no_default,
obj: str = "DataFrame",
) -> None:
"""
Expand Down Expand Up @@ -1103,7 +1141,12 @@ def assert_frame_equal(
Specify how to compare internal data. If False, compare by columns.
If True, compare by blocks.
check_exact : bool, default False
Whether to compare number exactly. Only takes effect for float dtypes.
Whether to compare number exactly.
.. versionchanged:: 2.2.0
Defaults to True for integer dtypes if none of
``check_exact``, ``rtol`` and ``atol`` are specified.
check_datetimelike_compat : bool, default False
Compare datetime-like which is comparable ignoring dtype.
check_categorical : bool, default True
Expand Down Expand Up @@ -1158,6 +1201,9 @@ def assert_frame_equal(
>>> assert_frame_equal(df1, df2, check_dtype=False)
"""
__tracebackhide__ = True
_rtol = rtol if rtol is not lib.no_default else 1.0e-5
_atol = atol if atol is not lib.no_default else 1.0e-8
_check_exact = check_exact if check_exact is not lib.no_default else False

# instance validation
_check_isinstance(left, right, DataFrame)
Expand All @@ -1181,11 +1227,11 @@ def assert_frame_equal(
right.index,
exact=check_index_type,
check_names=check_names,
check_exact=check_exact,
check_exact=_check_exact,
check_categorical=check_categorical,
check_order=not check_like,
rtol=rtol,
atol=atol,
rtol=_rtol,
atol=_atol,
obj=f"{obj}.index",
)

Expand All @@ -1195,11 +1241,11 @@ def assert_frame_equal(
right.columns,
exact=check_column_type,
check_names=check_names,
check_exact=check_exact,
check_exact=_check_exact,
check_categorical=check_categorical,
check_order=not check_like,
rtol=rtol,
atol=atol,
rtol=_rtol,
atol=_atol,
obj=f"{obj}.columns",
)

Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/util/test_assert_series_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,15 @@ def test_ea_and_numpy_no_dtype_check(val, check_exact, dtype):
left = Series([1, 2, val], dtype=dtype)
right = Series(pd.array([1, 2, val]))
tm.assert_series_equal(left, right, check_dtype=False, check_exact=check_exact)


def test_assert_series_equal_int_tol():
# GH#56646
left = Series([81, 18, 121, 38, 74, 72, 81, 81, 146, 81, 81, 170, 74, 74])
right = Series([72, 9, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72, 72])
tm.assert_series_equal(left, right, rtol=1.5)

tm.assert_frame_equal(left.to_frame(), right.to_frame(), rtol=1.5)
tm.assert_extension_array_equal(
left.astype("Int64").values, right.astype("Int64").values, rtol=1.5
)

0 comments on commit 4b6067b

Please sign in to comment.