Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PERF: assert_frame_equal / assert_series_equal #55971

Merged: 9 commits merged on Nov 17, 2023
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ Other Deprecations

Performance improvements
~~~~~~~~~~~~~~~~~~~~~~~~
- Performance improvement in :func:`.testing.assert_frame_equal` and :func:`.testing.assert_series_equal` for objects indexed by a :class:`MultiIndex` (:issue:`55949`)
- Performance improvement in :func:`.testing.assert_frame_equal` and :func:`.testing.assert_series_equal` (:issue:`55949`, :issue:`55971`)
- Performance improvement in :func:`concat` with ``axis=1`` and objects with unaligned indexes (:issue:`55084`)
- Performance improvement in :func:`merge_asof` when ``by`` is not ``None`` (:issue:`55580`, :issue:`55678`)
- Performance improvement in :func:`read_stata` for files with many variables (:issue:`55515`)
Expand Down
15 changes: 3 additions & 12 deletions pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
Series,
TimedeltaIndex,
)
from pandas.core.algorithms import take_nd
from pandas.core.arrays import (
DatetimeArray,
ExtensionArray,
Expand Down Expand Up @@ -246,13 +245,6 @@ def _check_types(left, right, obj: str = "Index") -> None:

assert_attr_equal("dtype", left, right, obj=obj)

def _get_ilevel_values(index, level):
# accept level number only
unique = index.levels[level]
level_codes = index.codes[level]
filled = take_nd(unique._values, level_codes, fill_value=unique._na_value)
return unique._shallow_copy(filled, name=index.names[level])

# instance validation
_check_isinstance(left, right, Index)

Expand Down Expand Up @@ -299,9 +291,8 @@ def _get_ilevel_values(index, level):
)
assert_numpy_array_equal(left.codes[level], right.codes[level])
except AssertionError:
# cannot use get_level_values here because it can change dtype
llevel = _get_ilevel_values(left, level)
rlevel = _get_ilevel_values(right, level)
llevel = left.get_level_values(level)
rlevel = right.get_level_values(level)

assert_index_equal(
llevel,
Expand Down Expand Up @@ -592,7 +583,7 @@ def raise_assert_detail(
{message}"""

if isinstance(index_values, Index):
index_values = np.array(index_values)
index_values = np.asarray(index_values)

if isinstance(index_values, np.ndarray):
msg += f"\n[index]: {pprint_thing(index_values)}"
Expand Down
17 changes: 16 additions & 1 deletion pandas/core/dtypes/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,22 @@ def _array_equivalent_object(left: np.ndarray, right: np.ndarray, strict_nan: bo

return lib.array_equivalent_object(ensure_object(left), ensure_object(right))

for left_value, right_value in zip(left, right):
mask = isna(left) & isna(right)
try:
if not lib.array_equivalent_object(
ensure_object(left[~mask]),
# [review thread on this line: marked as resolved by mroeschke]
ensure_object(right[~mask]),
):
return False
left_remaining = left[mask]
right_remaining = right[mask]
except ValueError:
# can raise a ValueError if left and right cannot be
# compared (e.g. nested arrays)
left_remaining = left
right_remaining = right

for left_value, right_value in zip(left_remaining, right_remaining):
if left_value is NaT and right_value is not NaT:
return False

Expand Down
12 changes: 3 additions & 9 deletions pandas/tests/dtypes/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,9 +560,7 @@ def test_array_equivalent_str(dtype):
)


@pytest.mark.parametrize(
"strict_nan", [pytest.param(True, marks=pytest.mark.xfail), False]
)
@pytest.mark.parametrize("strict_nan", [True, False])
def test_array_equivalent_nested(strict_nan):
# reached in groupby aggregations, make sure we use np.any when checking
# if the comparison is truthy
Expand All @@ -585,9 +583,7 @@ def test_array_equivalent_nested(strict_nan):


@pytest.mark.filterwarnings("ignore:elementwise comparison failed:DeprecationWarning")
@pytest.mark.parametrize(
"strict_nan", [pytest.param(True, marks=pytest.mark.xfail), False]
)
@pytest.mark.parametrize("strict_nan", [True, False])
def test_array_equivalent_nested2(strict_nan):
# more than one level of nesting
left = np.array(
Expand All @@ -612,9 +608,7 @@ def test_array_equivalent_nested2(strict_nan):
assert not array_equivalent(left, right, strict_nan=strict_nan)


@pytest.mark.parametrize(
"strict_nan", [pytest.param(True, marks=pytest.mark.xfail), False]
)
@pytest.mark.parametrize("strict_nan", [True, False])
def test_array_equivalent_nested_list(strict_nan):
left = np.array([[50, 70, 90], [20, 30]], dtype=object)
right = np.array([[50, 70, 90], [20, 30]], dtype=object)
Expand Down
Loading