Skip to content

Commit

Permalink
fix(python): Ensure assert_frame_not_equal and `assert_series_not_e…
Browse files Browse the repository at this point in the history
…qual` raise on mismatched input types (pola-rs#18402)
  • Loading branch information
alexander-beedie authored and r-brink committed Aug 29, 2024
1 parent 9009e6b commit c660bde
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 47 deletions.
3 changes: 2 additions & 1 deletion py-polars/polars/testing/asserts/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def assert_frame_not_equal(
"""
__tracebackhide__ = True

_assert_correct_input_type(left, right)
try:
assert_frame_equal(
left=left,
Expand All @@ -272,5 +273,5 @@ def assert_frame_not_equal(
except AssertionError:
return
else:
msg = "frames are equal"
msg = "frames are equal (but are expected not to be)"
raise AssertionError(msg)
26 changes: 17 additions & 9 deletions py-polars/polars/testing/asserts/series.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from polars._utils.deprecation import deprecate_renamed_parameter
from polars.datatypes import (
Expand All @@ -20,6 +20,19 @@
from polars import DataType


def _assert_correct_input_type(left: Any, right: Any) -> bool:
__tracebackhide__ = True

if not (isinstance(left, Series) and isinstance(right, Series)):
raise_assertion_error(
"inputs",
"unexpected input types",
type(left).__name__,
type(right).__name__,
)
return True


@deprecate_renamed_parameter("check_dtype", "check_dtypes", version="0.20.31")
def assert_series_equal(
left: Series,
Expand Down Expand Up @@ -90,13 +103,7 @@ def assert_series_equal(
"""
__tracebackhide__ = True

if not (isinstance(left, Series) and isinstance(right, Series)): # type: ignore[redundant-expr]
raise_assertion_error(
"inputs",
"unexpected input types",
type(left).__name__,
type(right).__name__,
)
_assert_correct_input_type(left, right)

if left.len() != right.len():
raise_assertion_error("Series", "length mismatch", left.len(), right.len())
Expand Down Expand Up @@ -404,6 +411,7 @@ def assert_series_not_equal(
"""
__tracebackhide__ = True

_assert_correct_input_type(left, right)
try:
assert_series_equal(
left=left,
Expand All @@ -419,5 +427,5 @@ def assert_series_not_equal(
except AssertionError:
return
else:
msg = "Series are equal"
msg = "Series are equal (but are expected not to be)"
raise AssertionError(msg)
33 changes: 26 additions & 7 deletions py-polars/tests/unit/testing/test_assert_frame_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,17 @@ def test_assert_frame_equal_pass() -> None:
assert_frame_equal(df1, df2)


def test_assert_frame_equal_types() -> None:
@pytest.mark.parametrize(
"assert_function",
[assert_frame_equal, assert_frame_not_equal],
)
def test_assert_frame_equal_types(assert_function: Any) -> None:
df1 = pl.DataFrame({"a": [1, 2]})
srs1 = pl.Series(values=[1, 2], name="a")
with pytest.raises(
AssertionError, match=r"inputs are different \(unexpected input types\)"
):
assert_frame_equal(df1, srs1) # type: ignore[arg-type]
assert_function(df1, srs1)


def test_assert_frame_equal_length_mismatch() -> None:
Expand All @@ -295,6 +299,7 @@ def test_assert_frame_equal_length_mismatch() -> None:
match=r"DataFrames are different \(number of rows does not match\)",
):
assert_frame_equal(df1, df2)
assert_frame_not_equal(df1, df2)


def test_assert_frame_equal_column_mismatch() -> None:
Expand All @@ -304,6 +309,7 @@ def test_assert_frame_equal_column_mismatch() -> None:
AssertionError, match="columns \\['a'\\] in left DataFrame, but not in right"
):
assert_frame_equal(df1, df2)
assert_frame_not_equal(df1, df2)


def test_assert_frame_equal_column_mismatch2() -> None:
Expand All @@ -314,6 +320,7 @@ def test_assert_frame_equal_column_mismatch2() -> None:
match="columns \\['b', 'c'\\] in right LazyFrame, but not in left",
):
assert_frame_equal(df1, df2)
assert_frame_not_equal(df1, df2)


def test_assert_frame_equal_column_mismatch_order() -> None:
Expand All @@ -323,6 +330,7 @@ def test_assert_frame_equal_column_mismatch_order() -> None:
assert_frame_equal(df1, df2)

assert_frame_equal(df1, df2, check_column_order=False)
assert_frame_not_equal(df1, df2)


def test_assert_frame_equal_check_row_order() -> None:
Expand All @@ -331,25 +339,33 @@ def test_assert_frame_equal_check_row_order() -> None:

with pytest.raises(AssertionError, match="value mismatch for column 'a'"):
assert_frame_equal(df1, df2)

assert_frame_equal(df1, df2, check_row_order=False)
assert_frame_not_equal(df1, df2)


def test_assert_frame_equal_check_row_col_order() -> None:
df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]})
df3 = pl.DataFrame({"b": [3, 4], "a": [2, 1]})
df2 = pl.DataFrame({"b": [3, 4], "a": [2, 1]})

with pytest.raises(AssertionError, match="columns are not in the same order"):
assert_frame_equal(df1, df3, check_row_order=False)
assert_frame_equal(df1, df3, check_row_order=False, check_column_order=False)
assert_frame_equal(df1, df2, check_row_order=False)

assert_frame_equal(df1, df2, check_row_order=False, check_column_order=False)
assert_frame_not_equal(df1, df2)


def test_assert_frame_equal_check_row_order_unsortable() -> None:
@pytest.mark.parametrize(
"assert_function",
[assert_frame_equal, assert_frame_not_equal],
)
def test_assert_frame_equal_check_row_order_unsortable(assert_function: Any) -> None:
df1 = pl.DataFrame({"a": [object(), object()], "b": [3, 4]})
df2 = pl.DataFrame({"a": [object(), object()], "b": [4, 3]})
with pytest.raises(
TypeError, match="cannot set `check_row_order=False`.*unsortable columns"
):
assert_frame_equal(df1, df2, check_row_order=False)
assert_function(df1, df2, check_row_order=False)


def test_assert_frame_equal_dtypes_mismatch() -> None:
Expand All @@ -360,6 +376,9 @@ def test_assert_frame_equal_dtypes_mismatch() -> None:
with pytest.raises(AssertionError, match="dtypes do not match"):
assert_frame_equal(df1, df2, check_column_order=False)

assert_frame_not_equal(df1, df2, check_column_order=False)
assert_frame_not_equal(df1, df2)


def test_assert_frame_not_equal() -> None:
df = pl.DataFrame({"a": [1, 2]})
Expand Down
Loading

0 comments on commit c660bde

Please sign in to comment.