Skip to content

Commit

Permalink
REF/TST: handle boolean dtypes in base extension tests (#54334)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Aug 1, 2023
1 parent b91d7f0 commit 2bb3557
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 214 deletions.
35 changes: 32 additions & 3 deletions pandas/tests/extension/base/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,25 @@ def test_grouping_grouper(self, data_for_grouping):
@pytest.mark.parametrize("as_index", [True, False])
def test_groupby_extension_agg(self, as_index, data_for_grouping):
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping})

is_bool = data_for_grouping.dtype._is_boolean
if is_bool:
# only 2 unique values, and the final entry has c==b
# (see data_for_grouping docstring)
df = df.iloc[:-1]

result = df.groupby("B", as_index=as_index).A.mean()
_, uniques = pd.factorize(data_for_grouping, sort=True)

exp_vals = [3.0, 1.0, 4.0]
if is_bool:
exp_vals = exp_vals[:-1]
if as_index:
index = pd.Index(uniques, name="B")
expected = pd.Series([3.0, 1.0, 4.0], index=index, name="A")
expected = pd.Series(exp_vals, index=index, name="A")
self.assert_series_equal(result, expected)
else:
expected = pd.DataFrame({"B": uniques, "A": [3.0, 1.0, 4.0]})
expected = pd.DataFrame({"B": uniques, "A": exp_vals})
self.assert_frame_equal(result, expected)

def test_groupby_agg_extension(self, data_for_grouping):
Expand Down Expand Up @@ -83,19 +93,38 @@ def test_groupby_agg_extension_timedelta_cumsum_with_named_aggregation(self):

def test_groupby_extension_no_sort(self, data_for_grouping):
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping})

is_bool = data_for_grouping.dtype._is_boolean
if is_bool:
# only 2 unique values, and the final entry has c==b
# (see data_for_grouping docstring)
df = df.iloc[:-1]

result = df.groupby("B", sort=False).A.mean()
_, index = pd.factorize(data_for_grouping, sort=False)

index = pd.Index(index, name="B")
expected = pd.Series([1.0, 3.0, 4.0], index=index, name="A")
exp_vals = [1.0, 3.0, 4.0]
if is_bool:
exp_vals = exp_vals[:-1]
expected = pd.Series(exp_vals, index=index, name="A")
self.assert_series_equal(result, expected)

def test_groupby_extension_transform(self, data_for_grouping):
is_bool = data_for_grouping.dtype._is_boolean

valid = data_for_grouping[~data_for_grouping.isna()]
df = pd.DataFrame({"A": [1, 1, 3, 3, 1, 4], "B": valid})
is_bool = data_for_grouping.dtype._is_boolean
if is_bool:
# only 2 unique values, and the final entry has c==b
# (see data_for_grouping docstring)
df = df.iloc[:-1]

result = df.groupby("B").A.transform(len)
expected = pd.Series([3, 3, 2, 2, 3, 1], name="A")
if is_bool:
expected = expected[:-1]

self.assert_series_equal(result, expected)

Expand Down
52 changes: 48 additions & 4 deletions pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,22 @@ def test_argsort_missing(self, data_missing_for_sorting):

def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting, na_value):
# GH 24382
is_bool = data_for_sorting.dtype._is_boolean

exp_argmax = 1
exp_argmax_repeated = 3
if is_bool:
# See data_for_sorting docstring
exp_argmax = 0
exp_argmax_repeated = 1

# data_for_sorting -> [B, C, A] with A < B < C
assert data_for_sorting.argmax() == 1
assert data_for_sorting.argmax() == exp_argmax
assert data_for_sorting.argmin() == 2

# with repeated values -> first occurrence
data = data_for_sorting.take([2, 0, 0, 1, 1, 2])
assert data.argmax() == 3
assert data.argmax() == exp_argmax_repeated
assert data.argmin() == 0

# with missing values
Expand Down Expand Up @@ -244,8 +252,15 @@ def test_unique(self, data, box, method):

def test_factorize(self, data_for_grouping):
codes, uniques = pd.factorize(data_for_grouping, use_na_sentinel=True)
expected_codes = np.array([0, 0, -1, -1, 1, 1, 0, 2], dtype=np.intp)
expected_uniques = data_for_grouping.take([0, 4, 7])

is_bool = data_for_grouping.dtype._is_boolean
if is_bool:
# only 2 unique values
expected_codes = np.array([0, 0, -1, -1, 1, 1, 0, 0], dtype=np.intp)
expected_uniques = data_for_grouping.take([0, 4])
else:
expected_codes = np.array([0, 0, -1, -1, 1, 1, 0, 2], dtype=np.intp)
expected_uniques = data_for_grouping.take([0, 4, 7])

tm.assert_numpy_array_equal(codes, expected_codes)
self.assert_extension_array_equal(uniques, expected_uniques)
Expand Down Expand Up @@ -457,6 +472,9 @@ def test_hash_pandas_object_works(self, data, as_frame):
self.assert_equal(a, b)

def test_searchsorted(self, data_for_sorting, as_series):
if data_for_sorting.dtype._is_boolean:
return self._test_searchsorted_bool_dtypes(data_for_sorting, as_series)

b, c, a = data_for_sorting
arr = data_for_sorting.take([2, 0, 1]) # to get [a, b, c]

Expand All @@ -480,6 +498,32 @@ def test_searchsorted(self, data_for_sorting, as_series):
sorter = np.array([1, 2, 0])
assert data_for_sorting.searchsorted(a, sorter=sorter) == 0

def _test_searchsorted_bool_dtypes(self, data_for_sorting, as_series):
    """Run the ``searchsorted`` checks for a boolean-like dtype.

    We call this from test_searchsorted in cases where we have a
    boolean-like dtype. The non-bool test assumes we have more than 2
    unique values.

    Parameters
    ----------
    data_for_sorting : ExtensionArray
        Only its dtype is used; the values under test are rebuilt as
        ``[True, False]`` in that dtype.
    as_series : bool
        If True, the sorted array is wrapped in a Series before searching.
    """
    dtype = data_for_sorting.dtype
    unsorted = pd.array([True, False], dtype=dtype)
    b, a = unsorted
    # Pass the dtype explicitly so the sorted array does not depend on
    # _from_sequence inferring it back from two scalars.
    arr = type(unsorted)._from_sequence([a, b], dtype=dtype)

    if as_series:
        arr = pd.Series(arr)
    assert arr.searchsorted(a) == 0
    assert arr.searchsorted(a, side="right") == 1

    assert arr.searchsorted(b) == 1
    assert arr.searchsorted(b, side="right") == 2

    result = arr.searchsorted(arr.take([0, 1]))
    expected = np.array([0, 1], dtype=np.intp)

    tm.assert_numpy_array_equal(result, expected)

    # sorter: [1, 0] puts the unsorted [b, a] array into ascending order
    sorter = np.array([1, 0])
    assert unsorted.searchsorted(a, sorter=sorter) == 0

def test_where_series(self, data, na_value, as_frame):
assert data[0] != data[1]
cls = type(data)
Expand Down
8 changes: 7 additions & 1 deletion pandas/tests/extension/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def data_for_sorting():
This should be three items [B, C, A] with
A < B < C
For boolean dtypes (for which there are only 2 values available),
set B=C=True
"""
raise NotImplementedError

Expand Down Expand Up @@ -117,7 +120,10 @@ def data_for_grouping():
Expected to be like [B, B, NA, NA, A, A, B, C]
Where A < B < C and NA is missing
Where A < B < C and NA is missing.
If a dtype has _is_boolean = True, i.e. only 2 unique non-NA entries,
then set C=B.
"""
raise NotImplementedError

Expand Down
60 changes: 1 addition & 59 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,38 +587,6 @@ def test_reduce_series(


class TestBaseGroupby(base.BaseGroupbyTests):
def test_groupby_extension_no_sort(self, data_for_grouping, request):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
if pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
super().test_groupby_extension_no_sort(data_for_grouping)

def test_groupby_extension_transform(self, data_for_grouping, request):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
if pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
super().test_groupby_extension_transform(data_for_grouping)

@pytest.mark.parametrize("as_index", [True, False])
def test_groupby_extension_agg(self, as_index, data_for_grouping, request):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
if pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=ValueError,
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
super().test_groupby_extension_agg(as_index, data_for_grouping)

def test_in_numeric_groupby(self, data_for_grouping):
dtype = data_for_grouping.dtype
if is_string_dtype(dtype):
Expand Down Expand Up @@ -845,13 +813,7 @@ def test_argmin_argmax(
self, data_for_sorting, data_missing_for_sorting, na_value, request
):
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
if pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
if pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
request.node.add_marker(
pytest.mark.xfail(
reason=f"No pyarrow kernel for {pa_dtype}",
Expand Down Expand Up @@ -888,16 +850,6 @@ def test_argreduce_series(
data_missing_for_sorting, op_name, skipna, expected
)

def test_factorize(self, data_for_grouping, request):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
if pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
super().test_factorize(data_for_grouping)

_combine_le_expected_dtype = "bool[pyarrow]"

def test_combine_add(self, data_repeated, request):
Expand All @@ -913,16 +865,6 @@ def test_combine_add(self, data_repeated, request):
else:
super().test_combine_add(data_repeated)

def test_searchsorted(self, data_for_sorting, as_series, request):
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
if pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
super().test_searchsorted(data_for_sorting, as_series)

def test_basic_equals(self, data):
# https://github.com/pandas-dev/pandas/issues/34660
assert pd.Series(data).equals(pd.Series(data))
Expand Down
Loading

0 comments on commit 2bb3557

Please sign in to comment.