diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index c4ee25c4b..56b83e176 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -547,7 +547,8 @@ S1 = TypeVar( | Period | Interval | CategoricalDtype - | BaseOffset, + | BaseOffset + | list[str], ) S2 = TypeVar( @@ -566,7 +567,8 @@ S2 = TypeVar( | Period | Interval | CategoricalDtype - | BaseOffset, + | BaseOffset + | list[str], ) IndexingInt: TypeAlias = ( diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index 8f44bc5e3..763aae027 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -261,7 +261,9 @@ class Index(IndexOpsMixin[S1]): **kwargs, ) -> Self: ... @property - def str(self) -> StringMethods[Self, MultiIndex, np_ndarray_bool]: ... + def str( + self, + ) -> StringMethods[Self, MultiIndex, np_ndarray_bool, Index[list[str]]]: ... def is_(self, other) -> bool: ... def __len__(self) -> int: ... def __array__(self, dtype=...) -> np.ndarray: ... diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 2907cec00..b91d3842b 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -252,6 +252,26 @@ class Series(IndexOpsMixin[S1], NDFrame): copy: bool = ..., ) -> Series[Any]: ... @overload + def __new__( + cls, + data: Sequence[list[str]], + index: Axes | None = ..., + *, + dtype: Dtype = ..., + name: Hashable = ..., + copy: bool = ..., + ) -> Series[list[str]]: ... + @overload + def __new__( + cls, + data: Sequence[str], + index: Axes | None = ..., + *, + dtype: Dtype = ..., + name: Hashable = ..., + copy: bool = ..., + ) -> Series[str]: ... + @overload def __new__( cls, data: ( @@ -1199,7 +1219,9 @@ class Series(IndexOpsMixin[S1], NDFrame): ) -> Series[S1]: ... def to_period(self, freq: _str | None = ..., copy: _bool = ...) -> DataFrame: ... @property - def str(self) -> StringMethods[Series, DataFrame, Series[bool]]: ... + def str( + self, + ) -> StringMethods[Series, DataFrame, Series[bool], Series[list[str]]]: ... @property def dt(self) -> CombinedDatetimelikeProperties: ... @property diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index 7e0dc880a..b952ced0d 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -15,6 +15,7 @@ import numpy as np import pandas as pd from pandas import ( DataFrame, + Index, MultiIndex, Series, ) @@ -28,10 +29,12 @@ from pandas._typing import ( # The _TS type is what is used for the result of str.split with expand=True _TS = TypeVar("_TS", DataFrame, MultiIndex) +# The _TS2 type is what is used for the result of str.split with expand=False +_TS2 = TypeVar("_TS2", Series[list[str]], Index[list[str]]) # The _TM type is what is used for the result of str.match _TM = TypeVar("_TM", Series[bool], np_ndarray_bool) -class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM]): +class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): def __init__(self, data: T) -> None: ... def __getitem__(self, key: slice | int) -> T: ... def __iter__(self) -> T: ... @@ -66,12 +69,19 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM]): ) -> _TS: ... @overload def split( - self, pat: str = ..., *, n: int = ..., expand: bool = ..., regex: bool = ... - ) -> T: ... + self, + pat: str = ..., + *, + n: int = ..., + expand: Literal[False] = ..., + regex: bool = ..., + ) -> _TS2: ... @overload def rsplit(self, pat: str = ..., *, n: int = ..., expand: Literal[True]) -> _TS: ... @overload - def rsplit(self, pat: str = ..., *, n: int = ..., expand: bool = ...) -> T: ... + def rsplit( + self, pat: str = ..., *, n: int = ..., expand: Literal[False] = ... + ) -> _TS2: ... @overload def partition(self, sep: str = ...) -> pd.DataFrame: ... @overload diff --git a/tests/test_indexes.py b/tests/test_indexes.py index aab49c405..468908ad5 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -111,8 +111,25 @@ def test_difference_none() -> None: def test_str_split() -> None: # GH 194 ind = pd.Index(["a-b", "c-d"]) - check(assert_type(ind.str.split("-"), "pd.Index[str]"), pd.Index) + check(assert_type(ind.str.split("-"), "pd.Index[list[str]]"), pd.Index, list) check(assert_type(ind.str.split("-", expand=True), pd.MultiIndex), pd.MultiIndex) + check( + assert_type(ind.str.split("-", expand=False), "pd.Index[list[str]]"), + pd.Index, + list, + ) + + +def test_str_rsplit() -> None: + # GH 1074 + ind = pd.Index(["a-b", "c-d"]) + check(assert_type(ind.str.rsplit("-"), "pd.Index[list[str]]"), pd.Index, list) + check(assert_type(ind.str.rsplit("-", expand=True), pd.MultiIndex), pd.MultiIndex) + check( + assert_type(ind.str.rsplit("-", expand=False), "pd.Index[list[str]]"), + pd.Index, + list, + ) def test_str_match() -> None: diff --git a/tests/test_series.py b/tests/test_series.py index f35fd20aa..1aca73c50 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1548,14 +1548,24 @@ def test_string_accessors(): check(assert_type(s.str.rindex("p"), pd.Series), pd.Series) check(assert_type(s.str.rjust(80), pd.Series), pd.Series) check(assert_type(s.str.rpartition("p"), pd.DataFrame), pd.DataFrame) - check(assert_type(s.str.rsplit("a"), pd.Series), pd.Series) + check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]"), pd.Series, list) check(assert_type(s.str.rsplit("a", expand=True), pd.DataFrame), pd.DataFrame) + check( + assert_type(s.str.rsplit("a", expand=False), "pd.Series[list[str]]"), + pd.Series, + list, + ) check(assert_type(s.str.rstrip(), pd.Series), pd.Series) check(assert_type(s.str.slice(0, 4, 2), pd.Series), pd.Series) check(assert_type(s.str.slice_replace(0, 2, "XX"), pd.Series), pd.Series) - check(assert_type(s.str.split("a"), pd.Series), pd.Series) + check(assert_type(s.str.split("a"), "pd.Series[list[str]]"), pd.Series, list) # GH 194 check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame) + check( + assert_type(s.str.split("a", expand=False), "pd.Series[list[str]]"), + pd.Series, + list, + ) check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, np.bool_) check( assert_type(s.str.startswith(("a", "b")), "pd.Series[bool]"),