From 715585de0d66383c51ce290ad6b18a036254d007 Mon Sep 17 00:00:00 2001 From: aaronchucarroll <120818400+aaronchucarroll@users.noreply.github.com> Date: Mon, 9 Sep 2024 17:38:04 -0400 Subject: [PATCH] ENH: Add dtype argument to StringMethods get_dummies() (#59577) --- doc/source/whatsnew/v3.0.0.rst | 1 + pandas/core/arrays/arrow/array.py | 15 +++- pandas/core/arrays/categorical.py | 4 +- pandas/core/arrays/string_arrow.py | 19 ++++- pandas/core/strings/accessor.py | 27 ++++++- pandas/core/strings/base.py | 3 +- pandas/core/strings/object_array.py | 13 +++- pandas/tests/strings/test_get_dummies.py | 99 ++++++++++++++++++++---- 8 files changed, 154 insertions(+), 27 deletions(-) diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 9a29ff4d49966..819318e119668 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -55,6 +55,7 @@ Other enhancements - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`) - :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`) - :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`) +- :meth:`str.get_dummies` now accepts a ``dtype`` parameter to specify the dtype of the resulting DataFrame (:issue:`47872`) - Multiplying two :class:`DateOffset` objects will now raise a ``TypeError`` instead of a ``RecursionError`` (:issue:`59442`) - Restore support for reading Stata 104-format and enable reading 103-format dta files (:issue:`58554`) - Support passing a :class:`Iterable[Hashable]` input to :meth:`DataFrame.drop_duplicates` (:issue:`59237`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 15f9ba611a642..4edf464be74f1 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -41,6 +41,7 @@ is_list_like, is_numeric_dtype, is_scalar, + pandas_dtype, ) from pandas.core.dtypes.dtypes import DatetimeTZDtype from pandas.core.dtypes.missing import isna @@ -2475,7 +2476,9 @@ def _str_findall(self, pat: str, flags: int = 0) -> Self: result = self._apply_elementwise(predicate) return type(self)(pa.chunked_array(result)) - def _str_get_dummies(self, sep: str = "|"): + def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None): + if dtype is None: + dtype = np.bool_ split = pc.split_pattern(self._pa_array, sep) flattened_values = pc.list_flatten(split) uniques = flattened_values.unique() @@ -2485,7 +2488,15 @@ def _str_get_dummies(self, sep: str = "|"): n_cols = len(uniques) indices = pc.index_in(flattened_values, uniques_sorted).to_numpy() indices = indices + np.arange(n_rows).repeat(lengths) * n_cols - dummies = np.zeros(n_rows * n_cols, dtype=np.bool_) + _dtype = pandas_dtype(dtype) + dummies_dtype: NpDtype + if isinstance(_dtype, np.dtype): + dummies_dtype = _dtype + else: + dummies_dtype = np.bool_ + dummies = np.zeros(n_rows * n_cols, dtype=dummies_dtype) + if dtype == str: + dummies[:] = False dummies[indices] = True dummies = dummies.reshape((n_rows, n_cols)) result = type(self)(pa.array(list(dummies))) diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index c613a345686cc..8e0225b31e17b 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -2681,11 +2681,11 @@ def _str_map( result = NumpyExtensionArray(categories.to_numpy())._str_map(f, na_value, dtype) return take_nd(result, codes, fill_value=na_value) - def _str_get_dummies(self, sep: str = "|"): + def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None): # sep may not be in categories. Just bail on this. from pandas.core.arrays import NumpyExtensionArray - return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep) + return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep, dtype) # ------------------------------------------------------------------------ # GroupBy Methods diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 1e5adf106752f..fa8c662b68f3c 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -56,6 +56,7 @@ ArrayLike, AxisInt, Dtype, + NpDtype, Scalar, Self, npt, @@ -425,12 +426,22 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None): return super()._str_find(sub, start, end) return ArrowStringArrayMixin._str_find(self, sub, start, end) - def _str_get_dummies(self, sep: str = "|"): - dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep) + def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None): + if dtype is None: + dtype = np.int64 + dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies( + sep, dtype + ) if len(labels) == 0: - return np.empty(shape=(0, 0), dtype=np.int64), labels + return np.empty(shape=(0, 0), dtype=dtype), labels dummies = np.vstack(dummies_pa.to_numpy()) - return dummies.astype(np.int64, copy=False), labels + _dtype = pandas_dtype(dtype) + dummies_dtype: NpDtype + if isinstance(_dtype, np.dtype): + dummies_dtype = _dtype + else: + dummies_dtype = np.bool_ + return dummies.astype(dummies_dtype, copy=False), labels def _convert_int_result(self, result): if self.dtype.na_value is np.nan: diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index bdb88e981bcda..6d10365a1b968 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -26,6 +26,7 @@ from pandas.core.dtypes.common import ( ensure_object, is_bool_dtype, + is_extension_array_dtype, is_integer, is_list_like, is_object_dtype, @@ -54,6 +55,8 @@ Iterator, ) + from pandas._typing import NpDtype + from pandas import ( DataFrame, Index, @@ -2431,7 +2434,11 @@ def wrap( return self._wrap_result(result) @forbid_nonstring_types(["bytes"]) - def get_dummies(self, sep: str = "|"): + def get_dummies( + self, + sep: str = "|", + dtype: NpDtype | None = None, + ): """ Return DataFrame of dummy/indicator variables for Series. @@ -2442,6 +2449,8 @@ def get_dummies(self, sep: str = "|"): ---------- sep : str, default "|" String to split on. + dtype : dtype, default np.int64 + Data type for new columns. Only a single dtype is allowed. Returns ------- @@ -2466,10 +2475,24 @@ def get_dummies(self, sep: str = "|"): 0 1 1 0 1 0 0 0 2 1 0 1 + + >>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies(dtype=bool) + a b c + 0 True True False + 1 False False False + 2 True False True """ + from pandas.core.frame import DataFrame + # we need to cast to Series of strings as only that has all # methods available for making the dummies... - result, name = self._data.array._str_get_dummies(sep) + result, name = self._data.array._str_get_dummies(sep, dtype) + if is_extension_array_dtype(dtype) or isinstance(dtype, ArrowDtype): + return self._wrap_result( + DataFrame(result, columns=name, dtype=dtype), + name=name, + returns_string=False, + ) return self._wrap_result( result, name=name, diff --git a/pandas/core/strings/base.py b/pandas/core/strings/base.py index 1281a03e297f9..97d906e3df077 100644 --- a/pandas/core/strings/base.py +++ b/pandas/core/strings/base.py @@ -16,6 +16,7 @@ import re from pandas._typing import ( + NpDtype, Scalar, Self, ) @@ -163,7 +164,7 @@ def _str_wrap(self, width: int, **kwargs): pass @abc.abstractmethod - def _str_get_dummies(self, sep: str = "|"): + def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None): pass @abc.abstractmethod diff --git a/pandas/core/strings/object_array.py b/pandas/core/strings/object_array.py index c6b18d7049c57..6211c7b528db9 100644 --- a/pandas/core/strings/object_array.py +++ b/pandas/core/strings/object_array.py @@ -18,6 +18,7 @@ import pandas._libs.ops as libops from pandas.util._exceptions import find_stack_level +from pandas.core.dtypes.common import pandas_dtype from pandas.core.dtypes.missing import isna from pandas.core.strings.base import BaseStringArrayMethods @@ -398,9 +399,11 @@ def _str_wrap(self, width: int, **kwargs): tw = textwrap.TextWrapper(**kwargs) return self._str_map(lambda s: "\n".join(tw.wrap(s))) - def _str_get_dummies(self, sep: str = "|"): + def _str_get_dummies(self, sep: str = "|", dtype: NpDtype | None = None): from pandas import Series + if dtype is None: + dtype = np.int64 arr = Series(self).fillna("") try: arr = sep + arr + sep @@ -412,7 +415,13 @@ def _str_get_dummies(self, sep: str = "|"): tags.update(ts) tags2 = sorted(tags - {""}) - dummies = np.empty((len(arr), len(tags2)), dtype=np.int64) + _dtype = pandas_dtype(dtype) + dummies_dtype: NpDtype + if isinstance(_dtype, np.dtype): + dummies_dtype = _dtype + else: + dummies_dtype = np.bool_ + dummies = np.empty((len(arr), len(tags2)), dtype=dummies_dtype) def _isin(test_elements: str, element: str) -> bool: return element in test_elements diff --git a/pandas/tests/strings/test_get_dummies.py b/pandas/tests/strings/test_get_dummies.py index 31386e4e342ae..0656f505dc745 100644 --- a/pandas/tests/strings/test_get_dummies.py +++ b/pandas/tests/strings/test_get_dummies.py @@ -1,4 +1,7 @@ import numpy as np +import pytest + +import pandas.util._test_decorators as td from pandas import ( DataFrame, @@ -8,6 +11,11 @@ _testing as tm, ) +try: + import pyarrow as pa +except ImportError: + pa = None + def test_get_dummies(any_string_dtype): s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype) @@ -32,22 +40,85 @@ def test_get_dummies_index(): tm.assert_index_equal(result, expected) -def test_get_dummies_with_name_dummy(any_string_dtype): - # GH 12180 - # Dummies named 'name' should work as expected - s = Series(["a", "b,name", "b"], dtype=any_string_dtype) - result = s.str.get_dummies(",") - expected = DataFrame([[1, 0, 0], [0, 1, 1], [0, 1, 0]], columns=["a", "b", "name"]) +# GH#47872 +@pytest.mark.parametrize( + "dtype", + [ + np.uint8, + np.int16, + np.uint16, + np.int32, + np.uint32, + np.int64, + np.uint64, + bool, + "Int8", + "Int16", + "Int32", + "Int64", + "boolean", + ], +) +def test_get_dummies_with_dtype(any_string_dtype, dtype): + s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype) + result = s.str.get_dummies("|", dtype=dtype) + expected = DataFrame( + [[1, 1, 0], [1, 0, 1], [0, 0, 0]], columns=list("abc"), dtype=dtype + ) tm.assert_frame_equal(result, expected) -def test_get_dummies_with_name_dummy_index(): - # GH 12180 - # Dummies named 'name' should work as expected - idx = Index(["a|b", "name|c", "b|name"]) - result = idx.str.get_dummies("|") +# GH#47872 +@td.skip_if_no("pyarrow") +@pytest.mark.parametrize( + "dtype", + [ + "int8[pyarrow]", + "uint8[pyarrow]", + "int16[pyarrow]", + "uint16[pyarrow]", + "int32[pyarrow]", + "uint32[pyarrow]", + "int64[pyarrow]", + "uint64[pyarrow]", + "bool[pyarrow]", + ], +) +def test_get_dummies_with_pyarrow_dtype(any_string_dtype, dtype): + s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype) + result = s.str.get_dummies("|", dtype=dtype) + expected = DataFrame( + [[1, 1, 0], [1, 0, 1], [0, 0, 0]], + columns=list("abc"), + dtype=dtype, + ) + tm.assert_frame_equal(result, expected) - expected = MultiIndex.from_tuples( - [(1, 1, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1)], names=("a", "b", "c", "name") + +# GH#47872 +def test_get_dummies_with_str_dtype(any_string_dtype): + s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype) + result = s.str.get_dummies("|", dtype=str) + expected = DataFrame( + [["T", "T", "F"], ["T", "F", "T"], ["F", "F", "F"]], + columns=list("abc"), + dtype=str, ) - tm.assert_index_equal(result, expected) + tm.assert_frame_equal(result, expected) + + +# GH#47872 +@td.skip_if_no("pyarrow") +def test_get_dummies_with_pa_str_dtype(any_string_dtype): + s = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype) + result = s.str.get_dummies("|", dtype="str[pyarrow]") + expected = DataFrame( + [ + ["true", "true", "false"], + ["true", "false", "true"], + ["false", "false", "false"], + ], + columns=list("abc"), + dtype="str[pyarrow]", + ) + tm.assert_frame_equal(result, expected)