Skip to content

Commit

Permalink
Backport PR pandas-dev#54752: REGR: groupby.count returning string dt…
Browse files Browse the repository at this point in the history
…ype instead of numeric for string input
  • Loading branch information
phofl authored and meeseeksmachine committed Aug 25, 2023
1 parent 4f66163 commit e121f78
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class providing the base-class of operations.
IntegerArray,
SparseArray,
)
from pandas.core.arrays.string_ import StringDtype
from pandas.core.base import (
PandasObject,
SelectionMixin,
Expand Down Expand Up @@ -2261,7 +2262,9 @@ def hfunc(bvalues: ArrayLike) -> ArrayLike:
return IntegerArray(
counted[0], mask=np.zeros(counted.shape[1], dtype=np.bool_)
)
elif isinstance(bvalues, ArrowExtensionArray):
elif isinstance(bvalues, ArrowExtensionArray) and not isinstance(
bvalues.dtype, StringDtype
):
return type(bvalues)._from_sequence(counted[0])
if is_series:
assert counted.ndim == 2
Expand Down
11 changes: 11 additions & 0 deletions pandas/tests/groupby/test_counting.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,14 @@ def __eq__(self, other):
result = df.groupby("grp").count()
expected = DataFrame({"a": [2, 2]}, index=Index(list("ab"), name="grp"))
tm.assert_frame_equal(result, expected)


def test_count_arrow_string_array(any_string_dtype):
# GH#54751
pytest.importorskip("pyarrow")
df = DataFrame(
{"a": [1, 2, 3], "b": Series(["a", "b", "a"], dtype=any_string_dtype)}
)
result = df.groupby("a").count()
expected = DataFrame({"b": 1}, index=Index([1, 2, 3], name="a"))
tm.assert_frame_equal(result, expected)

0 comments on commit e121f78

Please sign in to comment.