Skip to content

Commit

Permalink
ENH: Add dtype argument to StringMethods get_dummies() (#59577)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronchucarroll authored Sep 9, 2024
1 parent 83fd9ba commit 715585d
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 27 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Other enhancements
- :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
- :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
- :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`)
- :meth:`str.get_dummies` now accepts a ``dtype`` parameter to specify the dtype of the resulting DataFrame (:issue:`47872`)
- Multiplying two :class:`DateOffset` objects will now raise a ``TypeError`` instead of a ``RecursionError`` (:issue:`59442`)
- Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`)
- Support passing a :class:`Iterable[Hashable]` input to :meth:`DataFrame.drop_duplicates` (:issue:`59237`)
Expand Down
15 changes: 13 additions & 2 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
is_list_like,
is_numeric_dtype,
is_scalar,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.missing import isna
Expand Down Expand Up @@ -2475,7 +2476,9 @@ def _str_findall(self, pat: str, flags: int = 0) -> Self:
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
if dtype is None:
dtype = np.bool_
split = pc.split_pattern(self._pa_array, sep)
flattened_values = pc.list_flatten(split)
uniques = flattened_values.unique()
Expand All @@ -2485,7 +2488,15 @@ def _str_get_dummies(self, sep: str = "|"):
n_cols = len(uniques)
indices = pc.index_in(flattened_values, uniques_sorted).to_numpy()
indices = indices + np.arange(n_rows).repeat(lengths) * n_cols
dummies = np.zeros(n_rows * n_cols, dtype=np.bool_)
_dtype = pandas_dtype(dtype)
dummies_dtype: NpDtype
if isinstance(_dtype, np.dtype):
dummies_dtype = _dtype
else:
dummies_dtype = np.bool_
dummies = np.zeros(n_rows * n_cols, dtype=dummies_dtype)
if dtype == str:
dummies[:] = False
dummies[indices] = True
dummies = dummies.reshape((n_rows, n_cols))
result = type(self)(pa.array(list(dummies)))
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2681,11 +2681,11 @@ def _str_map(
result = NumpyExtensionArray(categories.to_numpy())._str_map(f, na_value, dtype)
return take_nd(result, codes, fill_value=na_value)

def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
# sep may not be in categories. Just bail on this.
from pandas.core.arrays import NumpyExtensionArray

return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep)
return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep, dtype)

# ------------------------------------------------------------------------
# GroupBy Methods
Expand Down
19 changes: 15 additions & 4 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
ArrayLike,
AxisInt,
Dtype,
NpDtype,
Scalar,
Self,
npt,
Expand Down Expand Up @@ -425,12 +426,22 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
return super()._str_find(sub, start, end)
return ArrowStringArrayMixin._str_find(self, sub, start, end)

def _str_get_dummies(self, sep: str = "|"):
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
if dtype is None:
dtype = np.int64
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(
sep, dtype
)
if len(labels) == 0:
return np.empty(shape=(0, 0), dtype=np.int64), labels
return np.empty(shape=(0, 0), dtype=dtype), labels
dummies = np.vstack(dummies_pa.to_numpy())
return dummies.astype(np.int64, copy=False), labels
_dtype = pandas_dtype(dtype)
dummies_dtype: NpDtype
if isinstance(_dtype, np.dtype):
dummies_dtype = _dtype
else:
dummies_dtype = np.bool_
return dummies.astype(dummies_dtype, copy=False), labels

def _convert_int_result(self, result):
if self.dtype.na_value is np.nan:
Expand Down
27 changes: 25 additions & 2 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pandas.core.dtypes.common import (
ensure_object,
is_bool_dtype,
is_extension_array_dtype,
is_integer,
is_list_like,
is_object_dtype,
Expand Down Expand Up @@ -54,6 +55,8 @@
Iterator,
)

from pandas._typing import NpDtype

from pandas import (
DataFrame,
Index,
Expand Down Expand Up @@ -2431,7 +2434,11 @@ def wrap(
return self._wrap_result(result)

@forbid_nonstring_types(["bytes"])
def get_dummies(self, sep: str = "|"):
def get_dummies(
self,
sep: str = "|",
dtype: NpDtype | None = None,
):
"""
Return DataFrame of dummy/indicator variables for Series.
Expand All @@ -2442,6 +2449,8 @@ def get_dummies(self, sep: str = "|"):
----------
sep : str, default "|"
String to split on.
dtype : dtype, default np.int64
Data type for new columns. Only a single dtype is allowed.
Returns
-------
Expand All @@ -2466,10 +2475,24 @@ def get_dummies(self, sep: str = "|"):
0 1 1 0
1 0 0 0
2 1 0 1
>>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies(dtype=bool)
a b c
0 True True False
1 False False False
2 True False True
"""
from pandas.core.frame import DataFrame

# we need to cast to Series of strings as only that has all
# methods available for making the dummies...
result, name = self._data.array._str_get_dummies(sep)
result, name = self._data.array._str_get_dummies(sep, dtype)
if is_extension_array_dtype(dtype) or isinstance(dtype, ArrowDtype):
return self._wrap_result(
DataFrame(result, columns=name, dtype=dtype),
name=name,
returns_string=False,
)
return self._wrap_result(
result,
name=name,
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/strings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re

from pandas._typing import (
NpDtype,
Scalar,
Self,
)
Expand Down Expand Up @@ -163,7 +164,7 @@ def _str_wrap(self, width: int, **kwargs):
pass

@abc.abstractmethod
def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
pass

@abc.abstractmethod
Expand Down
13 changes: 11 additions & 2 deletions pandas/core/strings/object_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pandas._libs.ops as libops
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.common import pandas_dtype
from pandas.core.dtypes.missing import isna

from pandas.core.strings.base import BaseStringArrayMethods
Expand Down Expand Up @@ -398,9 +399,11 @@ def _str_wrap(self, width: int, **kwargs):
tw = textwrap.TextWrapper(**kwargs)
return self._str_map(lambda s: "\n".join(tw.wrap(s)))

def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None):
from pandas import Series

if dtype is None:
dtype = np.int64
arr = Series(self).fillna("")
try:
arr = sep + arr + sep
Expand All @@ -412,7 +415,13 @@ def _str_get_dummies(self, sep: str = "|"):
tags.update(ts)
tags2 = sorted(tags - {""})

dummies = np.empty((len(arr), len(tags2)), dtype=np.int64)
_dtype = pandas_dtype(dtype)
dummies_dtype: NpDtype
if isinstance(_dtype, np.dtype):
dummies_dtype = _dtype
else:
dummies_dtype = np.bool_
dummies = np.empty((len(arr), len(tags2)), dtype=dummies_dtype)

def _isin(test_elements: str, element: str) -> bool:
return element in test_elements
Expand Down
99 changes: 85 additions & 14 deletions pandas/tests/strings/test_get_dummies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

from pandas import (
DataFrame,
Expand All @@ -8,6 +11,11 @@
_testing as tm,
)

try:
import pyarrow as pa
except ImportError:
pa = None


def test_get_dummies(any_string_dtype):
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
Expand All @@ -32,22 +40,85 @@ def test_get_dummies_index():
tm.assert_index_equal(result, expected)


def test_get_dummies_with_name_dummy(any_string_dtype):
# GH 12180
# Dummies named 'name' should work as expected
s = Series(["a", "b,name", "b"], dtype=any_string_dtype)
result = s.str.get_dummies(",")
expected = DataFrame([[1, 0, 0], [0, 1, 1], [0, 1, 0]], columns=["a", "b", "name"])
# GH#47872
@pytest.mark.parametrize(
"dtype",
[
np.uint8,
np.int16,
np.uint16,
np.int32,
np.uint32,
np.int64,
np.uint64,
bool,
"Int8",
"Int16",
"Int32",
"Int64",
"boolean",
],
)
def test_get_dummies_with_dtype(any_string_dtype, dtype):
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
result = s.str.get_dummies("|", dtype=dtype)
expected = DataFrame(
[[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc"), dtype=dtype
)
tm.assert_frame_equal(result, expected)


def test_get_dummies_with_name_dummy_index():
# GH 12180
# Dummies named 'name' should work as expected
idx = Index(["a|b", "name|c", "b|name"])
result = idx.str.get_dummies("|")
# GH#47872
@td.skip_if_no("pyarrow")
@pytest.mark.parametrize(
"dtype",
[
"int8[pyarrow]",
"uint8[pyarrow]",
"int16[pyarrow]",
"uint16[pyarrow]",
"int32[pyarrow]",
"uint32[pyarrow]",
"int64[pyarrow]",
"uint64[pyarrow]",
"bool[pyarrow]",
],
)
def test_get_dummies_with_pyarrow_dtype(any_string_dtype, dtype):
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
result = s.str.get_dummies("|", dtype=dtype)
expected = DataFrame(
[[1, 1, 0], [1, 0, 1], [0, 0, 0]],
columns=list("abc"),
dtype=dtype,
)
tm.assert_frame_equal(result, expected)

expected = MultiIndex.from_tuples(
[(1, 1, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1)], names=("a", "b", "c", "name")

# GH#47872
def test_get_dummies_with_str_dtype(any_string_dtype):
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
result = s.str.get_dummies("|", dtype=str)
expected = DataFrame(
[["T", "T", "F"], ["T", "F", "T"], ["F", "F", "F"]],
columns=list("abc"),
dtype=str,
)
tm.assert_index_equal(result, expected)
tm.assert_frame_equal(result, expected)


# GH#47872
@td.skip_if_no("pyarrow")
def test_get_dummies_with_pa_str_dtype(any_string_dtype):
s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
result = s.str.get_dummies("|", dtype="str[pyarrow]")
expected = DataFrame(
[
["true", "true", "false"],
["true", "false", "true"],
["false", "false", "false"],
],
columns=list("abc"),
dtype="str[pyarrow]",
)
tm.assert_frame_equal(result, expected)

0 comments on commit 715585d

Please sign in to comment.