Skip to content

Commit

Permalink
PERF: assert_frame_equal / assert_series_equal (#55971)
Browse files Browse the repository at this point in the history
* improve perf of index assertions

* whatsnew

* faster _array_equivalent_object

* add comment

* remove xfail

* skip mask if not needed
  • Loading branch information
lukemanley authored Nov 17, 2023
1 parent dbf8aaf commit 32ebcfc
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 28 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ Other Deprecations

Performance improvements
~~~~~~~~~~~~~~~~~~~~~~~~
- Performance improvement in :func:`.testing.assert_frame_equal` and :func:`.testing.assert_series_equal` for objects indexed by a :class:`MultiIndex` (:issue:`55949`)
- Performance improvement in :func:`.testing.assert_frame_equal` and :func:`.testing.assert_series_equal` (:issue:`55949`, :issue:`55971`)
- Performance improvement in :func:`concat` with ``axis=1`` and objects with unaligned indexes (:issue:`55084`)
- Performance improvement in :func:`merge_asof` when ``by`` is not ``None`` (:issue:`55580`, :issue:`55678`)
- Performance improvement in :func:`read_stata` for files with many variables (:issue:`55515`)
Expand Down
15 changes: 3 additions & 12 deletions pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
Series,
TimedeltaIndex,
)
from pandas.core.algorithms import take_nd
from pandas.core.arrays import (
DatetimeArray,
ExtensionArray,
Expand Down Expand Up @@ -246,13 +245,6 @@ def _check_types(left, right, obj: str = "Index") -> None:

assert_attr_equal("dtype", left, right, obj=obj)

def _get_ilevel_values(index, level):
# accept level number only
unique = index.levels[level]
level_codes = index.codes[level]
filled = take_nd(unique._values, level_codes, fill_value=unique._na_value)
return unique._shallow_copy(filled, name=index.names[level])

# instance validation
_check_isinstance(left, right, Index)

Expand Down Expand Up @@ -299,9 +291,8 @@ def _get_ilevel_values(index, level):
)
assert_numpy_array_equal(left.codes[level], right.codes[level])
except AssertionError:
# cannot use get_level_values here because it can change dtype
llevel = _get_ilevel_values(left, level)
rlevel = _get_ilevel_values(right, level)
llevel = left.get_level_values(level)
rlevel = right.get_level_values(level)

assert_index_equal(
llevel,
Expand Down Expand Up @@ -592,7 +583,7 @@ def raise_assert_detail(
{message}"""

if isinstance(index_values, Index):
index_values = np.array(index_values)
index_values = np.asarray(index_values)

if isinstance(index_values, np.ndarray):
msg += f"\n[index]: {pprint_thing(index_values)}"
Expand Down
29 changes: 23 additions & 6 deletions pandas/core/dtypes/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,12 +562,29 @@ def _array_equivalent_datetimelike(left: np.ndarray, right: np.ndarray):


def _array_equivalent_object(left: np.ndarray, right: np.ndarray, strict_nan: bool):
if not strict_nan:
# isna considers NaN and None to be equivalent.

return lib.array_equivalent_object(ensure_object(left), ensure_object(right))

for left_value, right_value in zip(left, right):
left = ensure_object(left)
right = ensure_object(right)

mask: npt.NDArray[np.bool_] | None = None
if strict_nan:
mask = isna(left) & isna(right)
if not mask.any():
mask = None

try:
if mask is None:
return lib.array_equivalent_object(left, right)
if not lib.array_equivalent_object(left[~mask], right[~mask]):
return False
left_remaining = left[mask]
right_remaining = right[mask]
except ValueError:
# can raise a ValueError if left and right cannot be
# compared (e.g. nested arrays)
left_remaining = left
right_remaining = right

for left_value, right_value in zip(left_remaining, right_remaining):
if left_value is NaT and right_value is not NaT:
return False

Expand Down
12 changes: 3 additions & 9 deletions pandas/tests/dtypes/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,9 +560,7 @@ def test_array_equivalent_str(dtype):
)


@pytest.mark.parametrize(
"strict_nan", [pytest.param(True, marks=pytest.mark.xfail), False]
)
@pytest.mark.parametrize("strict_nan", [True, False])
def test_array_equivalent_nested(strict_nan):
# reached in groupby aggregations, make sure we use np.any when checking
# if the comparison is truthy
Expand All @@ -585,9 +583,7 @@ def test_array_equivalent_nested(strict_nan):


@pytest.mark.filterwarnings("ignore:elementwise comparison failed:DeprecationWarning")
@pytest.mark.parametrize(
"strict_nan", [pytest.param(True, marks=pytest.mark.xfail), False]
)
@pytest.mark.parametrize("strict_nan", [True, False])
def test_array_equivalent_nested2(strict_nan):
# more than one level of nesting
left = np.array(
Expand All @@ -612,9 +608,7 @@ def test_array_equivalent_nested2(strict_nan):
assert not array_equivalent(left, right, strict_nan=strict_nan)


@pytest.mark.parametrize(
"strict_nan", [pytest.param(True, marks=pytest.mark.xfail), False]
)
@pytest.mark.parametrize("strict_nan", [True, False])
def test_array_equivalent_nested_list(strict_nan):
left = np.array([[50, 70, 90], [20, 30]], dtype=object)
right = np.array([[50, 70, 90], [20, 30]], dtype=object)
Expand Down

0 comments on commit 32ebcfc

Please sign in to comment.