Skip to content

Commit

Permalink
Simplify and fix overloads for methods with inplace parameter (#1105)
Browse files Browse the repository at this point in the history
The return value of these methods varies on the `inplace` argument.
  • Loading branch information
brianhelba authored Feb 14, 2025
1 parent 2986c87 commit e5ce0f9
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 284 deletions.
159 changes: 49 additions & 110 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -779,19 +779,6 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
errors: IgnoreRaise = ...,
) -> Self: ...
@overload
def rename(
self,
mapper: Renamer | None = ...,
*,
index: Renamer | None = ...,
columns: Renamer | None = ...,
axis: Axis | None = ...,
copy: bool = ...,
inplace: bool = ...,
level: Level | None = ...,
errors: IgnoreRaise = ...,
) -> Self | None: ...
@overload
def fillna(
self,
value: Scalar | NAType | dict | Series | DataFrame | None = ...,
Expand All @@ -812,25 +799,15 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
inplace: Literal[False] = ...,
) -> Self: ...
@overload
def fillna(
self,
value: Scalar | NAType | dict | Series | DataFrame | None = ...,
*,
axis: Axis | None = ...,
inplace: _bool | None = ...,
limit: int = ...,
downcast: dict | None = ...,
) -> Self | None: ...
@overload
def replace(
self,
to_replace=...,
value: Scalar | NAType | Sequence | Mapping | Pattern | None = ...,
*,
inplace: Literal[True],
limit: int | None = ...,
regex=...,
method: ReplaceMethod = ...,
inplace: Literal[True],
) -> None: ...
@overload
def replace(
Expand All @@ -843,17 +820,6 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
regex=...,
method: ReplaceMethod = ...,
) -> Self: ...
@overload
def replace(
self,
to_replace=...,
value: Scalar | NAType | Sequence | Mapping | Pattern | None = ...,
*,
inplace: _bool | None = ...,
limit: int | None = ...,
regex=...,
method: ReplaceMethod = ...,
) -> Self | None: ...
def shift(
self,
periods: int = ...,
Expand Down Expand Up @@ -919,18 +885,6 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
allow_duplicates: _bool = ...,
names: Hashable | Sequence[Hashable] = ...,
) -> Self: ...
@overload
def reset_index(
self,
level: Level | Sequence[Level] = ...,
*,
drop: _bool = ...,
inplace: _bool | None = ...,
col_level: int | _str = ...,
col_fill: Hashable = ...,
allow_duplicates: _bool = ...,
names: Hashable | Sequence[Hashable] = ...,
) -> Self | None: ...
def isna(self) -> Self: ...
def isnull(self) -> Self: ...
def notna(self) -> Self: ...
Expand Down Expand Up @@ -958,17 +912,6 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
ignore_index: _bool = ...,
) -> Self: ...
@overload
def dropna(
self,
*,
axis: Axis = ...,
how: Literal["any", "all"] = ...,
thresh: int | None = ...,
subset: ListLikeU | Scalar | None = ...,
inplace: _bool | None = ...,
ignore_index: _bool = ...,
) -> Self | None: ...
@overload
def drop_duplicates(
self,
subset: Hashable | Iterable[Hashable] | None = ...,
Expand All @@ -986,15 +929,6 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
inplace: Literal[False] = ...,
ignore_index: _bool = ...,
) -> Self: ...
@overload
def drop_duplicates(
self,
subset: Hashable | Iterable[Hashable] | None = ...,
*,
keep: NaPosition | _bool = ...,
inplace: _bool = ...,
ignore_index: _bool = ...,
) -> Self | None: ...
def duplicated(
self,
subset: Hashable | Iterable[Hashable] | None = ...,
Expand Down Expand Up @@ -1027,19 +961,6 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
key: Callable | None = ...,
) -> Self: ...
@overload
def sort_values(
self,
by: _str | Sequence[_str],
*,
axis: Axis = ...,
ascending: _bool | Sequence[_bool] = ...,
inplace: _bool | None = ...,
kind: SortKind = ...,
na_position: NaPosition = ...,
ignore_index: _bool = ...,
key: Callable | None = ...,
) -> Self | None: ...
@overload
def sort_index(
self,
*,
Expand Down Expand Up @@ -1068,20 +989,6 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
key: Callable | None = ...,
) -> Self: ...
@overload
def sort_index(
self,
*,
axis: Axis = ...,
level: Level | list[int] | list[_str] | None = ...,
ascending: _bool | Sequence[_bool] = ...,
inplace: _bool | None = ...,
kind: SortKind = ...,
na_position: NaPosition = ...,
sort_remaining: _bool = ...,
ignore_index: _bool = ...,
key: Callable | None = ...,
) -> Self | None: ...
@overload
def value_counts(
self,
subset: Sequence[Hashable] | None = ...,
Expand Down Expand Up @@ -1824,13 +1731,24 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
limit_area: Literal["inside", "outside"] | None = ...,
downcast: dict | None = ...,
) -> Self: ...
@overload
def clip(
self,
lower: float | AnyArrayLike | None = ...,
upper: float | AnyArrayLike | None = ...,
*,
axis: Axis | None = ...,
inplace: _bool = ...,
inplace: Literal[True],
**kwargs,
) -> None: ...
@overload
def clip(
self,
lower: float | AnyArrayLike | None = ...,
upper: float | AnyArrayLike | None = ...,
*,
axis: Axis | None = ...,
inplace: Literal[False] = ...,
**kwargs,
) -> Self: ...
def copy(self, deep: _bool = ...) -> Self: ...
Expand Down Expand Up @@ -1963,19 +1881,6 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
inplace: Literal[False] = ...,
**kwargs,
) -> Self: ...
@overload
def interpolate(
self,
method: InterpolateOptions = ...,
*,
axis: Axis = ...,
limit: int | None = ...,
inplace: _bool | None = ...,
limit_direction: Literal["forward", "backward", "both"] = ...,
limit_area: Literal["inside", "outside"] | None = ...,
downcast: Literal["infer"] | None = ...,
**kwargs,
) -> Self | None: ...
def keys(self) -> Index: ...
def kurt(
self,
Expand All @@ -1997,6 +1902,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
def last_valid_index(self) -> Scalar: ...
def le(self, other, axis: Axis = ..., level: Level | None = ...) -> Self: ...
def lt(self, other, axis: Axis = ..., level: Level | None = ...) -> Self: ...
@overload
def mask(
self,
cond: (
Expand All @@ -2008,7 +1914,23 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
),
other: Scalar | Series[S1] | DataFrame | Callable | NAType | None = ...,
*,
inplace: _bool = ...,
inplace: Literal[True],
axis: Axis | None = ...,
level: Level | None = ...,
) -> None: ...
@overload
def mask(
self,
cond: (
Series
| DataFrame
| np.ndarray
| Callable[[DataFrame], DataFrame]
| Callable[[Any], _bool]
),
other: Scalar | Series[S1] | DataFrame | Callable | NAType | None = ...,
*,
inplace: Literal[False] = ...,
axis: Axis | None = ...,
level: Level | None = ...,
) -> Self: ...
Expand Down Expand Up @@ -2470,6 +2392,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
numeric_only: _bool = ...,
**kwargs,
) -> Series: ...
@overload
def where(
self,
cond: (
Expand All @@ -2481,7 +2404,23 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
),
other=...,
*,
inplace: _bool = ...,
inplace: Literal[True],
axis: Axis | None = ...,
level: Level | None = ...,
) -> None: ...
@overload
def where(
self,
cond: (
Series
| DataFrame
| np.ndarray
| Callable[[DataFrame], DataFrame]
| Callable[[Any], _bool]
),
other=...,
*,
inplace: Literal[False] = ...,
axis: Axis | None = ...,
level: Level | None = ...,
) -> Self: ...
Expand Down
Loading

0 comments on commit e5ce0f9

Please sign in to comment.