Skip to content

Commit

Permalink
GH1074 Add type hint Series[list[str]] for Series.str.split with expand=False (#1075)
Browse files Browse the repository at this point in the history

* GH1074 Add type hint Series[list[str]] for Series.str.split with expand=False

* Updates:

    - fix Index.str.split method returning the wrong result;
    - add test for Index.str.split method with expand=False;
    - restore the changes made in pull request #1029.

* Update tests/test_indexes.py

Co-authored-by: Irv Lustig <irv@princeton.com>

* Update tests/test_series.py

Co-authored-by: Irv Lustig <irv@princeton.com>

* Update tests/test_series.py

Co-authored-by: Irv Lustig <irv@princeton.com>

* Updates:

    - combine two str.split overloads and keep only _TS and _TS2;
    - fix the test_str_split() test in test_indexes.py.

* pre-commit fixes

* Add type hints and tests for str.rsplit() for expand=False

---------

Co-authored-by: Irv Lustig <irv@princeton.com>
  • Loading branch information
pan-vlados and Dr-Irv authored Dec 18, 2024
1 parent 63dfe96 commit 109dc86
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 11 deletions.
6 changes: 4 additions & 2 deletions pandas-stubs/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,8 @@ S1 = TypeVar(
| Period
| Interval
| CategoricalDtype
| BaseOffset,
| BaseOffset
| list[str],
)

S2 = TypeVar(
Expand All @@ -566,7 +567,8 @@ S2 = TypeVar(
| Period
| Interval
| CategoricalDtype
| BaseOffset,
| BaseOffset
| list[str],
)

IndexingInt: TypeAlias = (
Expand Down
4 changes: 3 additions & 1 deletion pandas-stubs/core/indexes/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ class Index(IndexOpsMixin[S1]):
**kwargs,
) -> Self: ...
@property
def str(self) -> StringMethods[Self, MultiIndex, np_ndarray_bool]: ...
def str(
self,
) -> StringMethods[Self, MultiIndex, np_ndarray_bool, Index[list[str]]]: ...
def is_(self, other) -> bool: ...
def __len__(self) -> int: ...
def __array__(self, dtype=...) -> np.ndarray: ...
Expand Down
24 changes: 23 additions & 1 deletion pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,26 @@ class Series(IndexOpsMixin[S1], NDFrame):
copy: bool = ...,
) -> Series[Any]: ...
@overload
def __new__(
cls,
data: Sequence[list[str]],
index: Axes | None = ...,
*,
dtype: Dtype = ...,
name: Hashable = ...,
copy: bool = ...,
) -> Series[list[str]]: ...
@overload
def __new__(
cls,
data: Sequence[str],
index: Axes | None = ...,
*,
dtype: Dtype = ...,
name: Hashable = ...,
copy: bool = ...,
) -> Series[str]: ...
@overload
def __new__(
cls,
data: (
Expand Down Expand Up @@ -1199,7 +1219,9 @@ class Series(IndexOpsMixin[S1], NDFrame):
) -> Series[S1]: ...
def to_period(self, freq: _str | None = ..., copy: _bool = ...) -> DataFrame: ...
@property
def str(self) -> StringMethods[Series, DataFrame, Series[bool]]: ...
def str(
self,
) -> StringMethods[Series, DataFrame, Series[bool], Series[list[str]]]: ...
@property
def dt(self) -> CombinedDatetimelikeProperties: ...
@property
Expand Down
18 changes: 14 additions & 4 deletions pandas-stubs/core/strings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import numpy as np
import pandas as pd
from pandas import (
DataFrame,
Index,
MultiIndex,
Series,
)
Expand All @@ -28,10 +29,12 @@ from pandas._typing import (

# The _TS type is what is used for the result of str.split with expand=True
_TS = TypeVar("_TS", DataFrame, MultiIndex)
# The _TS2 type is what is used for the result of str.split with expand=False
_TS2 = TypeVar("_TS2", Series[list[str]], Index[list[str]])
# The _TM type is what is used for the result of str.match
_TM = TypeVar("_TM", Series[bool], np_ndarray_bool)

class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM]):
class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
def __init__(self, data: T) -> None: ...
def __getitem__(self, key: slice | int) -> T: ...
def __iter__(self) -> T: ...
Expand Down Expand Up @@ -66,12 +69,19 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM]):
) -> _TS: ...
@overload
def split(
self, pat: str = ..., *, n: int = ..., expand: bool = ..., regex: bool = ...
) -> T: ...
self,
pat: str = ...,
*,
n: int = ...,
expand: Literal[False] = ...,
regex: bool = ...,
) -> _TS2: ...
@overload
def rsplit(self, pat: str = ..., *, n: int = ..., expand: Literal[True]) -> _TS: ...
@overload
def rsplit(self, pat: str = ..., *, n: int = ..., expand: bool = ...) -> T: ...
def rsplit(
self, pat: str = ..., *, n: int = ..., expand: Literal[False] = ...
) -> _TS2: ...
@overload
def partition(self, sep: str = ...) -> pd.DataFrame: ...
@overload
Expand Down
19 changes: 18 additions & 1 deletion tests/test_indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,25 @@ def test_difference_none() -> None:
def test_str_split() -> None:
# GH 194
ind = pd.Index(["a-b", "c-d"])
check(assert_type(ind.str.split("-"), "pd.Index[str]"), pd.Index)
check(assert_type(ind.str.split("-"), "pd.Index[list[str]]"), pd.Index, list)
check(assert_type(ind.str.split("-", expand=True), pd.MultiIndex), pd.MultiIndex)
check(
assert_type(ind.str.split("-", expand=False), "pd.Index[list[str]]"),
pd.Index,
list,
)


def test_str_rsplit() -> None:
# GH 1074
ind = pd.Index(["a-b", "c-d"])
check(assert_type(ind.str.rsplit("-"), "pd.Index[list[str]]"), pd.Index, list)
check(assert_type(ind.str.rsplit("-", expand=True), pd.MultiIndex), pd.MultiIndex)
check(
assert_type(ind.str.rsplit("-", expand=False), "pd.Index[list[str]]"),
pd.Index,
list,
)


def test_str_match() -> None:
Expand Down
14 changes: 12 additions & 2 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,14 +1548,24 @@ def test_string_accessors():
check(assert_type(s.str.rindex("p"), pd.Series), pd.Series)
check(assert_type(s.str.rjust(80), pd.Series), pd.Series)
check(assert_type(s.str.rpartition("p"), pd.DataFrame), pd.DataFrame)
check(assert_type(s.str.rsplit("a"), pd.Series), pd.Series)
check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]"), pd.Series, list)
check(assert_type(s.str.rsplit("a", expand=True), pd.DataFrame), pd.DataFrame)
check(
assert_type(s.str.rsplit("a", expand=False), "pd.Series[list[str]]"),
pd.Series,
list,
)
check(assert_type(s.str.rstrip(), pd.Series), pd.Series)
check(assert_type(s.str.slice(0, 4, 2), pd.Series), pd.Series)
check(assert_type(s.str.slice_replace(0, 2, "XX"), pd.Series), pd.Series)
check(assert_type(s.str.split("a"), pd.Series), pd.Series)
check(assert_type(s.str.split("a"), "pd.Series[list[str]]"), pd.Series, list)
# GH 194
check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame)
check(
assert_type(s.str.split("a", expand=False), "pd.Series[list[str]]"),
pd.Series,
list,
)
check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, np.bool_)
check(
assert_type(s.str.startswith(("a", "b")), "pd.Series[bool]"),
Expand Down

0 comments on commit 109dc86

Please sign in to comment.