From 1cec5766c14d03c3eb06a57ccdb0482af5ba0af8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Sat, 30 Dec 2023 18:50:31 -0500 Subject: [PATCH 1/8] TYP: mostly Hashtable and ArrowExtensionArray --- pandas/_libs/hashtable.pyi | 31 ++- pandas/_libs/hashtable_class_helper.pxi.in | 2 +- pandas/compat/pickle_compat.py | 9 +- pandas/core/algorithms.py | 23 ++- pandas/core/arrays/arrow/array.py | 218 ++++++++++++--------- 5 files changed, 170 insertions(+), 113 deletions(-) diff --git a/pandas/_libs/hashtable.pyi b/pandas/_libs/hashtable.pyi index 555ec73acd9b2..3bb957812f0ed 100644 --- a/pandas/_libs/hashtable.pyi +++ b/pandas/_libs/hashtable.pyi @@ -2,6 +2,7 @@ from typing import ( Any, Hashable, Literal, + overload, ) import numpy as np @@ -180,18 +181,30 @@ class HashTable: na_value: object = ..., mask=..., ) -> npt.NDArray[np.intp]: ... + @overload def unique( self, values: np.ndarray, # np.ndarray[subclass-specific] - return_inverse: bool = ..., - mask=..., - ) -> ( - tuple[ - np.ndarray, # np.ndarray[subclass-specific] - npt.NDArray[np.intp], - ] - | np.ndarray - ): ... # np.ndarray[subclass-specific] + *, + return_inverse: Literal[False] = ..., + mask: None = ..., + ) -> np.ndarray: ... # np.ndarray[subclass-specific] + @overload + def unique( + self, + values: np.ndarray, # np.ndarray[subclass-specific] + *, + return_inverse: Literal[True], + mask: None = ..., + ) -> tuple[np.ndarray, npt.NDArray[np.intp],]: ... # np.ndarray[subclass-specific] + @overload + def unique( + self, + values: np.ndarray, # np.ndarray[subclass-specific] + *, + return_inverse: Literal[False] = ..., + mask: npt.NDArray[np.bool_], + ) -> tuple[np.ndarray, npt.NDArray[np.bool_],]: ... 
# np.ndarray[subclass-specific] def factorize( self, values: np.ndarray, # np.ndarray[subclass-specific] diff --git a/pandas/_libs/hashtable_class_helper.pxi.in b/pandas/_libs/hashtable_class_helper.pxi.in index c0723392496c1..988851ffe1ee1 100644 --- a/pandas/_libs/hashtable_class_helper.pxi.in +++ b/pandas/_libs/hashtable_class_helper.pxi.in @@ -755,7 +755,7 @@ cdef class {{name}}HashTable(HashTable): return uniques.to_array(), result_mask.to_array() return uniques.to_array() - def unique(self, const {{dtype}}_t[:] values, bint return_inverse=False, object mask=None): + def unique(self, const {{dtype}}_t[:] values, *, bint return_inverse=False, object mask=None): """ Calculate unique values and labels (no sorting!) diff --git a/pandas/compat/pickle_compat.py b/pandas/compat/pickle_compat.py index cd98087c06c18..ff589ebba4cf6 100644 --- a/pandas/compat/pickle_compat.py +++ b/pandas/compat/pickle_compat.py @@ -7,7 +7,10 @@ import copy import io import pickle as pkl -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Any, +) import numpy as np @@ -209,7 +212,7 @@ def load_newobj_ex(self) -> None: pass -def load(fh, encoding: str | None = None, is_verbose: bool = False): +def load(fh, encoding: str | None = None, is_verbose: bool = False) -> Any: """ Load a pickle, with a provided encoding, @@ -239,7 +242,7 @@ def loads( fix_imports: bool = True, encoding: str = "ASCII", errors: str = "strict", -): +) -> Any: """ Analogous to pickle._loads. """ diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 15a07da76d2f7..6f1ac75cebc0b 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -11,6 +11,7 @@ TYPE_CHECKING, Literal, cast, + overload, ) import warnings @@ -181,6 +182,20 @@ def _ensure_data(values: ArrayLike) -> np.ndarray: return ensure_object(values) +@overload +def _reconstruct_data( + values: ExtensionArray, dtype: DtypeObj, original: AnyArrayLike +) -> ExtensionArray: + ... 
+ + +@overload +def _reconstruct_data( + values: np.ndarray, dtype: DtypeObj, original: AnyArrayLike +) -> np.ndarray: + ... + + def _reconstruct_data( values: ArrayLike, dtype: DtypeObj, original: AnyArrayLike ) -> ArrayLike: @@ -259,7 +274,9 @@ def _ensure_arraylike(values, func_name: str) -> ArrayLike: } -def _get_hashtable_algo(values: np.ndarray): +def _get_hashtable_algo( + values: np.ndarray, +) -> tuple[type[htable.HashTable], np.ndarray]: """ Parameters ---------- @@ -1550,7 +1567,9 @@ def safe_sort( hash_klass, values = _get_hashtable_algo(values) # type: ignore[arg-type] t = hash_klass(len(values)) t.map_locations(values) - sorter = ensure_platform_int(t.lookup(ordered)) + # error: Argument 1 to "lookup" of "HashTable" has incompatible type + # "ExtensionArray | ndarray[Any, Any] | Index | Series"; expected "ndarray" + sorter = ensure_platform_int(t.lookup(ordered)) # type: ignore[arg-type] if use_na_sentinel: # take_nd is faster, but only works for na_sentinels of -1 diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index b1164301e6d79..464b4d6f5c36b 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -10,6 +10,7 @@ Callable, Literal, cast, + overload, ) import unicodedata @@ -157,6 +158,7 @@ def floordiv_compat( if TYPE_CHECKING: from collections.abc import Sequence + from pandas._libs.missing import NAType from pandas._typing import ( ArrayLike, AxisInt, @@ -280,7 +282,9 @@ def __init__(self, values: pa.Array | pa.ChunkedArray) -> None: self._dtype = ArrowDtype(self._pa_array.type) @classmethod - def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False): + def _from_sequence( + cls, scalars, *, dtype: Dtype | None = None, copy: bool = False + ) -> Self: """ Construct a new ExtensionArray from a sequence of scalars. 
""" @@ -292,7 +296,7 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = Fal @classmethod def _from_sequence_of_strings( cls, strings, *, dtype: Dtype | None = None, copy: bool = False - ): + ) -> Self: """ Construct a new ExtensionArray from a sequence of strings. """ @@ -675,7 +679,7 @@ def __setstate__(self, state) -> None: state["_pa_array"] = pa.chunked_array(data) self.__dict__.update(state) - def _cmp_method(self, other, op): + def _cmp_method(self, other, op) -> ArrowExtensionArray: pc_func = ARROW_CMP_FUNCS[op.__name__] if isinstance( other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray) @@ -701,7 +705,7 @@ def _cmp_method(self, other, op): ) return ArrowExtensionArray(result) - def _evaluate_op_method(self, other, op, arrow_funcs): + def _evaluate_op_method(self, other, op, arrow_funcs) -> Self: pa_type = self._pa_array.type other = self._box_pa(other) @@ -752,7 +756,7 @@ def _evaluate_op_method(self, other, op, arrow_funcs): result = pc_func(self._pa_array, other) return type(self)(result) - def _logical_method(self, other, op): + def _logical_method(self, other, op) -> Self: # For integer types `^`, `|`, `&` are bitwise operators and return # integer types. Otherwise these are boolean ops. if pa.types.is_integer(self._pa_array.type): @@ -760,7 +764,7 @@ def _logical_method(self, other, op): else: return self._evaluate_op_method(other, op, ARROW_LOGICAL_FUNCS) - def _arith_method(self, other, op): + def _arith_method(self, other, op) -> Self: return self._evaluate_op_method(other, op, ARROW_ARITHMETIC_FUNCS) def equals(self, other) -> bool: @@ -825,7 +829,17 @@ def isna(self) -> npt.NDArray[np.bool_]: return self._pa_array.is_null().to_numpy() - def any(self, *, skipna: bool = True, **kwargs): + # error: Signature of "any" incompatible with supertype + # "ExtensionArraySupportsAnyAll" + @overload # type: ignore[override] + def any(self, *, skipna: Literal[True] = ..., **kwargs) -> bool: + ... 
+ + @overload + def any(self, *, skipna: bool, **kwargs) -> bool | NAType: + ... + + def any(self, *, skipna: bool = True, **kwargs) -> bool | NAType: """ Return whether any element is truthy. @@ -883,7 +897,17 @@ def any(self, *, skipna: bool = True, **kwargs): """ return self._reduce("any", skipna=skipna, **kwargs) - def all(self, *, skipna: bool = True, **kwargs): + # error: Signature of "all" incompatible with supertype + # "ExtensionArraySupportsAnyAll" + @overload # type: ignore[override] + def all(self, *, skipna: Literal[True] = ..., **kwargs) -> bool: + ... + + @overload + def all(self, *, skipna: bool, **kwargs) -> bool | NAType: + ... + + def all(self, *, skipna: bool = True, **kwargs) -> bool | NAType: """ Return whether all elements are truthy. @@ -2027,7 +2051,7 @@ def _if_else( cond: npt.NDArray[np.bool_] | bool, left: ArrayLike | Scalar, right: ArrayLike | Scalar, - ): + ) -> pa.Array: """ Choose values based on a condition. @@ -2071,7 +2095,7 @@ def _replace_with_mask( values: pa.Array | pa.ChunkedArray, mask: npt.NDArray[np.bool_] | bool, replacements: ArrayLike | Scalar, - ): + ) -> pa.Array | pa.ChunkedArray: """ Replace items selected with a mask. 
@@ -2178,14 +2202,14 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]: for chunk in self._pa_array.iterchunks() ] - def _str_count(self, pat: str, flags: int = 0): + def _str_count(self, pat: str, flags: int = 0) -> Self: if flags: raise NotImplementedError(f"count not implemented with {flags=}") return type(self)(pc.count_substring_regex(self._pa_array, pat)) def _str_contains( self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True - ): + ) -> Self: if flags: raise NotImplementedError(f"contains not implemented with {flags=}") @@ -2198,7 +2222,7 @@ def _str_contains( result = result.fill_null(na) return type(self)(result) - def _str_startswith(self, pat: str | tuple[str, ...], na=None): + def _str_startswith(self, pat: str | tuple[str, ...], na=None) -> Self: if isinstance(pat, str): result = pc.starts_with(self._pa_array, pattern=pat) else: @@ -2215,7 +2239,7 @@ def _str_startswith(self, pat: str | tuple[str, ...], na=None): result = result.fill_null(na) return type(self)(result) - def _str_endswith(self, pat: str | tuple[str, ...], na=None): + def _str_endswith(self, pat: str | tuple[str, ...], na=None) -> Self: if isinstance(pat, str): result = pc.ends_with(self._pa_array, pattern=pat) else: @@ -2240,7 +2264,7 @@ def _str_replace( case: bool = True, flags: int = 0, regex: bool = True, - ): + ) -> Self: if isinstance(pat, re.Pattern) or callable(repl) or not case or flags: raise NotImplementedError( "replace is not supported with a re.Pattern, callable repl, " @@ -2259,29 +2283,28 @@ def _str_replace( ) return type(self)(result) - def _str_repeat(self, repeats: int | Sequence[int]): + def _str_repeat(self, repeats: int | Sequence[int]) -> Self: if not isinstance(repeats, int): raise NotImplementedError( f"repeat is not implemented when repeats is {type(repeats).__name__}" ) - else: - return type(self)(pc.binary_repeat(self._pa_array, repeats)) + return type(self)(pc.binary_repeat(self._pa_array, repeats)) def _str_match( 
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None - ): + ) -> Self: if not pat.startswith("^"): pat = f"^{pat}" return self._str_contains(pat, case, flags, na, regex=True) def _str_fullmatch( self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None - ): + ) -> Self: if not pat.endswith("$") or pat.endswith("//$"): pat = f"{pat}$" return self._str_match(pat, case, flags, na) - def _str_find(self, sub: str, start: int = 0, end: int | None = None): + def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self: if start != 0 and end is not None: slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) result = pc.find_substring(slices, sub) @@ -2298,7 +2321,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None): ) return type(self)(result) - def _str_join(self, sep: str): + def _str_join(self, sep: str) -> Self: if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string( self._pa_array.type ): @@ -2308,19 +2331,19 @@ def _str_join(self, sep: str): result = self._pa_array return type(self)(pc.binary_join(result, sep)) - def _str_partition(self, sep: str, expand: bool): + def _str_partition(self, sep: str, expand: bool) -> Self: predicate = lambda val: val.partition(sep) result = self._apply_elementwise(predicate) return type(self)(pa.chunked_array(result)) - def _str_rpartition(self, sep: str, expand: bool): + def _str_rpartition(self, sep: str, expand: bool) -> Self: predicate = lambda val: val.rpartition(sep) result = self._apply_elementwise(predicate) return type(self)(pa.chunked_array(result)) def _str_slice( self, start: int | None = None, stop: int | None = None, step: int | None = None - ): + ) -> Self: if start is None: start = 0 if step is None: @@ -2329,57 +2352,57 @@ def _str_slice( pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step) ) - def _str_isalnum(self): + def _str_isalnum(self) -> Self: return 
type(self)(pc.utf8_is_alnum(self._pa_array)) - def _str_isalpha(self): + def _str_isalpha(self) -> Self: return type(self)(pc.utf8_is_alpha(self._pa_array)) - def _str_isdecimal(self): + def _str_isdecimal(self) -> Self: return type(self)(pc.utf8_is_decimal(self._pa_array)) - def _str_isdigit(self): + def _str_isdigit(self) -> Self: return type(self)(pc.utf8_is_digit(self._pa_array)) - def _str_islower(self): + def _str_islower(self) -> Self: return type(self)(pc.utf8_is_lower(self._pa_array)) - def _str_isnumeric(self): + def _str_isnumeric(self) -> Self: return type(self)(pc.utf8_is_numeric(self._pa_array)) - def _str_isspace(self): + def _str_isspace(self) -> Self: return type(self)(pc.utf8_is_space(self._pa_array)) - def _str_istitle(self): + def _str_istitle(self) -> Self: return type(self)(pc.utf8_is_title(self._pa_array)) - def _str_isupper(self): + def _str_isupper(self) -> Self: return type(self)(pc.utf8_is_upper(self._pa_array)) - def _str_len(self): + def _str_len(self) -> Self: return type(self)(pc.utf8_length(self._pa_array)) - def _str_lower(self): + def _str_lower(self) -> Self: return type(self)(pc.utf8_lower(self._pa_array)) - def _str_upper(self): + def _str_upper(self) -> Self: return type(self)(pc.utf8_upper(self._pa_array)) - def _str_strip(self, to_strip=None): + def _str_strip(self, to_strip=None) -> Self: if to_strip is None: result = pc.utf8_trim_whitespace(self._pa_array) else: result = pc.utf8_trim(self._pa_array, characters=to_strip) return type(self)(result) - def _str_lstrip(self, to_strip=None): + def _str_lstrip(self, to_strip=None) -> Self: if to_strip is None: result = pc.utf8_ltrim_whitespace(self._pa_array) else: result = pc.utf8_ltrim(self._pa_array, characters=to_strip) return type(self)(result) - def _str_rstrip(self, to_strip=None): + def _str_rstrip(self, to_strip=None) -> Self: if to_strip is None: result = pc.utf8_rtrim_whitespace(self._pa_array) else: @@ -2396,12 +2419,12 @@ def _str_removeprefix(self, prefix: str): 
result = self._apply_elementwise(predicate) return type(self)(pa.chunked_array(result)) - def _str_casefold(self): + def _str_casefold(self) -> Self: predicate = lambda val: val.casefold() result = self._apply_elementwise(predicate) return type(self)(pa.chunked_array(result)) - def _str_encode(self, encoding: str, errors: str = "strict"): + def _str_encode(self, encoding: str, errors: str = "strict") -> Self: predicate = lambda val: val.encode(encoding, errors) result = self._apply_elementwise(predicate) return type(self)(pa.chunked_array(result)) @@ -2421,7 +2444,7 @@ def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): else: return type(self)(pc.struct_field(result, [0])) - def _str_findall(self, pat: str, flags: int = 0): + def _str_findall(self, pat: str, flags: int = 0) -> Self: regex = re.compile(pat, flags=flags) predicate = lambda val: regex.findall(val) result = self._apply_elementwise(predicate) @@ -2443,22 +2466,22 @@ def _str_get_dummies(self, sep: str = "|"): result = type(self)(pa.array(list(dummies))) return result, uniques_sorted.to_pylist() - def _str_index(self, sub: str, start: int = 0, end: int | None = None): + def _str_index(self, sub: str, start: int = 0, end: int | None = None) -> Self: predicate = lambda val: val.index(sub, start, end) result = self._apply_elementwise(predicate) return type(self)(pa.chunked_array(result)) - def _str_rindex(self, sub: str, start: int = 0, end: int | None = None): + def _str_rindex(self, sub: str, start: int = 0, end: int | None = None) -> Self: predicate = lambda val: val.rindex(sub, start, end) result = self._apply_elementwise(predicate) return type(self)(pa.chunked_array(result)) - def _str_normalize(self, form: str): + def _str_normalize(self, form: str) -> Self: predicate = lambda val: unicodedata.normalize(form, val) result = self._apply_elementwise(predicate) return type(self)(pa.chunked_array(result)) - def _str_rfind(self, sub: str, start: int = 0, end=None): + def _str_rfind(self, 
sub: str, start: int = 0, end=None) -> Self: predicate = lambda val: val.rfind(sub, start, end) result = self._apply_elementwise(predicate) return type(self)(pa.chunked_array(result)) @@ -2469,7 +2492,7 @@ def _str_split( n: int | None = -1, expand: bool = False, regex: bool | None = None, - ): + ) -> Self: if n in {-1, 0}: n = None if pat is None: @@ -2480,24 +2503,23 @@ def _str_split( split_func = functools.partial(pc.split_pattern, pattern=pat) return type(self)(split_func(self._pa_array, max_splits=n)) - def _str_rsplit(self, pat: str | None = None, n: int | None = -1): + def _str_rsplit(self, pat: str | None = None, n: int | None = -1) -> Self: if n in {-1, 0}: n = None if pat is None: return type(self)( pc.utf8_split_whitespace(self._pa_array, max_splits=n, reverse=True) ) - else: - return type(self)( - pc.split_pattern(self._pa_array, pat, max_splits=n, reverse=True) - ) + return type(self)( + pc.split_pattern(self._pa_array, pat, max_splits=n, reverse=True) + ) - def _str_translate(self, table: dict[int, str]): + def _str_translate(self, table: dict[int, str]) -> Self: predicate = lambda val: val.translate(table) result = self._apply_elementwise(predicate) return type(self)(pa.chunked_array(result)) - def _str_wrap(self, width: int, **kwargs): + def _str_wrap(self, width: int, **kwargs) -> Self: kwargs["width"] = width tw = textwrap.TextWrapper(**kwargs) predicate = lambda val: "\n".join(tw.wrap(val)) @@ -2505,13 +2527,13 @@ def _str_wrap(self, width: int, **kwargs): return type(self)(pa.chunked_array(result)) @property - def _dt_days(self): + def _dt_days(self) -> Self: return type(self)( pa.array(self._to_timedeltaarray().days, from_pandas=True, type=pa.int32()) ) @property - def _dt_hours(self): + def _dt_hours(self) -> Self: return type(self)( pa.array( [ @@ -2523,7 +2545,7 @@ def _dt_hours(self): ) @property - def _dt_minutes(self): + def _dt_minutes(self) -> Self: return type(self)( pa.array( [ @@ -2535,7 +2557,7 @@ def _dt_minutes(self): ) @property 
- def _dt_seconds(self): + def _dt_seconds(self) -> Self: return type(self)( pa.array( self._to_timedeltaarray().seconds, from_pandas=True, type=pa.int32() @@ -2543,7 +2565,7 @@ def _dt_seconds(self): ) @property - def _dt_milliseconds(self): + def _dt_milliseconds(self) -> Self: return type(self)( pa.array( [ @@ -2555,7 +2577,7 @@ def _dt_milliseconds(self): ) @property - def _dt_microseconds(self): + def _dt_microseconds(self) -> Self: return type(self)( pa.array( self._to_timedeltaarray().microseconds, @@ -2565,25 +2587,25 @@ def _dt_microseconds(self): ) @property - def _dt_nanoseconds(self): + def _dt_nanoseconds(self) -> Self: return type(self)( pa.array( self._to_timedeltaarray().nanoseconds, from_pandas=True, type=pa.int32() ) ) - def _dt_to_pytimedelta(self): + def _dt_to_pytimedelta(self) -> np.ndarray: data = self._pa_array.to_pylist() if self._dtype.pyarrow_dtype.unit == "ns": data = [None if ts is None else ts.to_pytimedelta() for ts in data] return np.array(data, dtype=object) - def _dt_total_seconds(self): + def _dt_total_seconds(self) -> Self: return type(self)( pa.array(self._to_timedeltaarray().total_seconds(), from_pandas=True) ) - def _dt_as_unit(self, unit: str): + def _dt_as_unit(self, unit: str) -> Self: if pa.types.is_date(self.dtype.pyarrow_dtype): raise NotImplementedError("as_unit not implemented for date types") pd_array = self._maybe_convert_datelike_array() @@ -2591,43 +2613,43 @@ def _dt_as_unit(self, unit: str): return type(self)(pa.array(pd_array.as_unit(unit), from_pandas=True)) @property - def _dt_year(self): + def _dt_year(self) -> Self: return type(self)(pc.year(self._pa_array)) @property - def _dt_day(self): + def _dt_day(self) -> Self: return type(self)(pc.day(self._pa_array)) @property - def _dt_day_of_week(self): + def _dt_day_of_week(self) -> Self: return type(self)(pc.day_of_week(self._pa_array)) _dt_dayofweek = _dt_day_of_week _dt_weekday = _dt_day_of_week @property - def _dt_day_of_year(self): + def _dt_day_of_year(self) 
-> Self: return type(self)(pc.day_of_year(self._pa_array)) _dt_dayofyear = _dt_day_of_year @property - def _dt_hour(self): + def _dt_hour(self) -> Self: return type(self)(pc.hour(self._pa_array)) - def _dt_isocalendar(self): + def _dt_isocalendar(self) -> Self: return type(self)(pc.iso_calendar(self._pa_array)) @property - def _dt_is_leap_year(self): + def _dt_is_leap_year(self) -> Self: return type(self)(pc.is_leap_year(self._pa_array)) @property - def _dt_is_month_start(self): + def _dt_is_month_start(self) -> Self: return type(self)(pc.equal(pc.day(self._pa_array), 1)) @property - def _dt_is_month_end(self): + def _dt_is_month_end(self) -> Self: result = pc.equal( pc.days_between( pc.floor_temporal(self._pa_array, unit="day"), @@ -2638,7 +2660,7 @@ def _dt_is_month_end(self): return type(self)(result) @property - def _dt_is_year_start(self): + def _dt_is_year_start(self) -> Self: return type(self)( pc.and_( pc.equal(pc.month(self._pa_array), 1), @@ -2647,7 +2669,7 @@ def _dt_is_year_start(self): ) @property - def _dt_is_year_end(self): + def _dt_is_year_end(self) -> Self: return type(self)( pc.and_( pc.equal(pc.month(self._pa_array), 12), @@ -2656,7 +2678,7 @@ def _dt_is_year_end(self): ) @property - def _dt_is_quarter_start(self): + def _dt_is_quarter_start(self) -> Self: result = pc.equal( pc.floor_temporal(self._pa_array, unit="quarter"), pc.floor_temporal(self._pa_array, unit="day"), @@ -2664,7 +2686,7 @@ def _dt_is_quarter_start(self): return type(self)(result) @property - def _dt_is_quarter_end(self): + def _dt_is_quarter_end(self) -> Self: result = pc.equal( pc.days_between( pc.floor_temporal(self._pa_array, unit="day"), @@ -2675,7 +2697,7 @@ def _dt_is_quarter_end(self): return type(self)(result) @property - def _dt_days_in_month(self): + def _dt_days_in_month(self) -> Self: result = pc.days_between( pc.floor_temporal(self._pa_array, unit="month"), pc.ceil_temporal(self._pa_array, unit="month"), @@ -2685,35 +2707,35 @@ def _dt_days_in_month(self): 
_dt_daysinmonth = _dt_days_in_month @property - def _dt_microsecond(self): + def _dt_microsecond(self) -> Self: return type(self)(pc.microsecond(self._pa_array)) @property - def _dt_minute(self): + def _dt_minute(self) -> Self: return type(self)(pc.minute(self._pa_array)) @property - def _dt_month(self): + def _dt_month(self) -> Self: return type(self)(pc.month(self._pa_array)) @property - def _dt_nanosecond(self): + def _dt_nanosecond(self) -> Self: return type(self)(pc.nanosecond(self._pa_array)) @property - def _dt_quarter(self): + def _dt_quarter(self) -> Self: return type(self)(pc.quarter(self._pa_array)) @property - def _dt_second(self): + def _dt_second(self) -> Self: return type(self)(pc.second(self._pa_array)) @property - def _dt_date(self): + def _dt_date(self) -> Self: return type(self)(self._pa_array.cast(pa.date32())) @property - def _dt_time(self): + def _dt_time(self) -> Self: unit = ( self.dtype.pyarrow_dtype.unit if self.dtype.pyarrow_dtype.unit in {"us", "ns"} @@ -2729,10 +2751,10 @@ def _dt_tz(self): def _dt_unit(self): return self.dtype.pyarrow_dtype.unit - def _dt_normalize(self): + def _dt_normalize(self) -> Self: return type(self)(pc.floor_temporal(self._pa_array, 1, "day")) - def _dt_strftime(self, format: str): + def _dt_strftime(self, format: str) -> Self: return type(self)(pc.strftime(self._pa_array, format=format)) def _round_temporally( @@ -2741,7 +2763,7 @@ def _round_temporally( freq, ambiguous: TimeAmbiguous = "raise", nonexistent: TimeNonexistent = "raise", - ): + ) -> Self: if ambiguous != "raise": raise NotImplementedError("ambiguous is not supported.") if nonexistent != "raise": @@ -2777,7 +2799,7 @@ def _dt_ceil( freq, ambiguous: TimeAmbiguous = "raise", nonexistent: TimeNonexistent = "raise", - ): + ) -> Self: return self._round_temporally("ceil", freq, ambiguous, nonexistent) def _dt_floor( @@ -2785,7 +2807,7 @@ def _dt_floor( freq, ambiguous: TimeAmbiguous = "raise", nonexistent: TimeNonexistent = "raise", - ): + ) -> Self: 
return self._round_temporally("floor", freq, ambiguous, nonexistent) def _dt_round( @@ -2793,20 +2815,20 @@ def _dt_round( freq, ambiguous: TimeAmbiguous = "raise", nonexistent: TimeNonexistent = "raise", - ): + ) -> Self: return self._round_temporally("round", freq, ambiguous, nonexistent) - def _dt_day_name(self, locale: str | None = None): + def _dt_day_name(self, locale: str | None = None) -> Self: if locale is None: locale = "C" return type(self)(pc.strftime(self._pa_array, format="%A", locale=locale)) - def _dt_month_name(self, locale: str | None = None): + def _dt_month_name(self, locale: str | None = None) -> Self: if locale is None: locale = "C" return type(self)(pc.strftime(self._pa_array, format="%B", locale=locale)) - def _dt_to_pydatetime(self): + def _dt_to_pydatetime(self) -> np.ndarray: if pa.types.is_date(self.dtype.pyarrow_dtype): raise ValueError( f"to_pydatetime cannot be called with {self.dtype.pyarrow_dtype} type. " @@ -2822,7 +2844,7 @@ def _dt_tz_localize( tz, ambiguous: TimeAmbiguous = "raise", nonexistent: TimeNonexistent = "raise", - ): + ) -> Self: if ambiguous != "raise": raise NotImplementedError(f"{ambiguous=} is not supported") nonexistent_pa = { @@ -2842,7 +2864,7 @@ def _dt_tz_localize( ) return type(self)(result) - def _dt_tz_convert(self, tz): + def _dt_tz_convert(self, tz) -> Self: if self.dtype.pyarrow_dtype.tz is None: raise TypeError( "Cannot convert tz-naive timestamps, use tz_localize to localize" From 1fc6492c3695e306d3610089d552ad5d20abac5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Sat, 30 Dec 2023 20:28:52 -0500 Subject: [PATCH 2/8] fix mypy stubtest --- pandas/_libs/hashtable_class_helper.pxi.in | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas/_libs/hashtable_class_helper.pxi.in b/pandas/_libs/hashtable_class_helper.pxi.in index 988851ffe1ee1..ed1284c34e110 100644 --- a/pandas/_libs/hashtable_class_helper.pxi.in +++ b/pandas/_libs/hashtable_class_helper.pxi.in 
@@ -1180,7 +1180,7 @@ cdef class StringHashTable(HashTable): return uniques.to_array(), labels.base # .base -> underlying ndarray return uniques.to_array() - def unique(self, ndarray[object] values, bint return_inverse=False, object mask=None): + def unique(self, ndarray[object] values, *, bint return_inverse=False, object mask=None): """ Calculate unique values and labels (no sorting!) @@ -1438,7 +1438,7 @@ cdef class PyObjectHashTable(HashTable): return uniques.to_array(), labels.base # .base -> underlying ndarray return uniques.to_array() - def unique(self, ndarray[object] values, bint return_inverse=False, object mask=None): + def unique(self, ndarray[object] values, *, bint return_inverse=False, object mask=None): """ Calculate unique values and labels (no sorting!) From 6f887251ba6ac6ec9f79a9eb04fc18db5e5aeb73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Sun, 31 Dec 2023 16:12:52 -0500 Subject: [PATCH 3/8] and return types for core.arrays --- pandas/_typing.py | 2 +- pandas/core/accessor.py | 2 +- pandas/core/arrays/arrow/extension_types.py | 2 +- pandas/core/arrays/base.py | 6 ++- pandas/core/arrays/categorical.py | 23 ++++++++--- pandas/core/arrays/datetimelike.py | 42 +++++++++------------ pandas/core/arrays/datetimes.py | 24 ++++++++---- pandas/core/arrays/interval.py | 13 +++++-- pandas/core/arrays/masked.py | 37 +++++++++++++++--- pandas/core/arrays/sparse/accessor.py | 11 +++++- pandas/core/arrays/sparse/array.py | 6 ++- pandas/core/arrays/string_.py | 10 +++-- pandas/core/arrays/string_arrow.py | 19 ++++++---- pandas/core/dtypes/cast.py | 35 +++++++++-------- pandas/core/internals/array_manager.py | 4 +- pandas/core/internals/managers.py | 12 +++--- pandas/core/strings/object_array.py | 8 ++-- pandas/core/tools/datetimes.py | 2 +- 18 files changed, 160 insertions(+), 98 deletions(-) diff --git a/pandas/_typing.py b/pandas/_typing.py index 3df9a47a35fca..aed7aa7f2cc8e 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ 
-137,7 +137,7 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[_T_co]: ... - def index(self, value: Any, /, start: int = 0, stop: int = ...) -> int: + def index(self, value: Any, start: int = ..., stop: int = ..., /) -> int: ... def count(self, value: Any, /) -> int: diff --git a/pandas/core/accessor.py b/pandas/core/accessor.py index 698abb2ec4989..9098a6f9664a9 100644 --- a/pandas/core/accessor.py +++ b/pandas/core/accessor.py @@ -54,7 +54,7 @@ class PandasDelegate: def _delegate_property_get(self, name: str, *args, **kwargs): raise TypeError(f"You cannot access the property {name}") - def _delegate_property_set(self, name: str, value, *args, **kwargs): + def _delegate_property_set(self, name: str, value, *args, **kwargs) -> None: raise TypeError(f"The property {name} cannot be set") def _delegate_method(self, name: str, *args, **kwargs): diff --git a/pandas/core/arrays/arrow/extension_types.py b/pandas/core/arrays/arrow/extension_types.py index d52b60df47adc..2fa5f7a882cc7 100644 --- a/pandas/core/arrays/arrow/extension_types.py +++ b/pandas/core/arrays/arrow/extension_types.py @@ -145,7 +145,7 @@ def patch_pyarrow() -> None: return class ForbiddenExtensionType(pyarrow.ExtensionType): - def __arrow_ext_serialize__(self): + def __arrow_ext_serialize__(self) -> bytes: return b"" @classmethod diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 59c6d911cfaef..b1a33c8fbd3a1 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -266,7 +266,9 @@ class ExtensionArray: # ------------------------------------------------------------------------ @classmethod - def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False): + def _from_sequence( + cls, scalars, *, dtype: Dtype | None = None, copy: bool = False + ) -> Self: """ Construct a new ExtensionArray from a sequence of scalars. 
@@ -329,7 +331,7 @@ def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self: @classmethod def _from_sequence_of_strings( cls, strings, *, dtype: Dtype | None = None, copy: bool = False - ): + ) -> Self: """ Construct a new ExtensionArray from a sequence of strings. diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 8a88227ad54a3..58809ba54ed56 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -597,7 +597,7 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike: return result - def to_list(self): + def to_list(self) -> list: """ Alias for tolist. """ @@ -1017,7 +1017,9 @@ def as_unordered(self) -> Self: """ return self.set_ordered(False) - def set_categories(self, new_categories, ordered=None, rename: bool = False): + def set_categories( + self, new_categories, ordered=None, rename: bool = False + ) -> Self: """ Set the categories to the specified new categories. @@ -1870,7 +1872,7 @@ def check_for_ordered(self, op) -> None: def argsort( self, *, ascending: bool = True, kind: SortKind = "quicksort", **kwargs - ): + ) -> npt.NDArray[np.intp]: """ Return the indices that would sort the Categorical. @@ -2618,7 +2620,15 @@ def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]: code_values = code_values[null_mask | (code_values >= 0)] return algorithms.isin(self.codes, code_values) - def _replace(self, *, to_replace, value, inplace: bool = False): + @overload + def _replace(self, *, to_replace, value, inplace: Literal[False] = ...) -> Self: + ... + + @overload + def _replace(self, *, to_replace, value, inplace: Literal[True]) -> None: + ... 
+ + def _replace(self, *, to_replace, value, inplace: bool = False) -> Self | None: from pandas import Index orig_dtype = self.dtype @@ -2666,6 +2676,7 @@ def _replace(self, *, to_replace, value, inplace: bool = False): ) if not inplace: return cat + return None # ------------------------------------------------------------------------ # String methods interface @@ -2901,8 +2912,8 @@ def _delegate_property_get(self, name: str): # error: Signature of "_delegate_property_set" incompatible with supertype # "PandasDelegate" - def _delegate_property_set(self, name: str, new_values): # type: ignore[override] - return setattr(self._parent, name, new_values) + def _delegate_property_set(self, name: str, new_values) -> None: # type: ignore[override] + setattr(self._parent, name, new_values) @property def codes(self) -> Series: diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 6ca74c4c05bc6..4668db8d75cd7 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -344,7 +344,7 @@ def _format_native_types( """ raise AbstractMethodError(self) - def _formatter(self, boxed: bool = False): + def _formatter(self, boxed: bool = False) -> Callable[[object], str]: # TODO: Remove Datetime & DatetimeTZ formatters. 
return "'{}'".format @@ -808,9 +808,8 @@ def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]: if self.dtype.kind in "mM": self = cast("DatetimeArray | TimedeltaArray", self) - # error: Item "ExtensionArray" of "ExtensionArray | ndarray[Any, Any]" - # has no attribute "as_unit" - values = values.as_unit(self.unit) # type: ignore[union-attr] + # error: "DatetimeLikeArrayMixin" has no attribute "as_unit" + values = values.as_unit(self.unit) # type: ignore[attr-defined] try: # error: Argument 1 to "_check_compatible_with" of "DatetimeLikeArrayMixin" @@ -1209,7 +1208,7 @@ def _add_timedeltalike_scalar(self, other): self, other = self._ensure_matching_resos(other) return self._add_timedeltalike(other) - def _add_timedelta_arraylike(self, other: TimedeltaArray): + def _add_timedelta_arraylike(self, other: TimedeltaArray) -> Self: """ Add a delta of a TimedeltaIndex @@ -1222,30 +1221,26 @@ def _add_timedelta_arraylike(self, other: TimedeltaArray): if len(self) != len(other): raise ValueError("cannot add indices of unequal length") - self = cast("DatetimeArray | TimedeltaArray", self) - - self, other = self._ensure_matching_resos(other) + self, other = cast( + "DatetimeArray | TimedeltaArray", self + )._ensure_matching_resos(other) return self._add_timedeltalike(other) @final - def _add_timedeltalike(self, other: Timedelta | TimedeltaArray): - self = cast("DatetimeArray | TimedeltaArray", self) - + def _add_timedeltalike(self, other: Timedelta | TimedeltaArray) -> Self: other_i8, o_mask = self._get_i8_values_and_mask(other) new_values = add_overflowsafe(self.asi8, np.asarray(other_i8, dtype="i8")) res_values = new_values.view(self._ndarray.dtype) new_freq = self._get_arithmetic_result_freq(other) - # error: Argument "dtype" to "_simple_new" of "DatetimeArray" has - # incompatible type "Union[dtype[datetime64], DatetimeTZDtype, - # dtype[timedelta64]]"; expected "Union[dtype[datetime64], DatetimeTZDtype]" + # error: Unexpected keyword argument "freq" for "_simple_new" 
of "NDArrayBacked" return type(self)._simple_new( - res_values, dtype=self.dtype, freq=new_freq # type: ignore[arg-type] + res_values, dtype=self.dtype, freq=new_freq # type: ignore[call-arg] ) @final - def _add_nat(self): + def _add_nat(self) -> Self: """ Add pd.NaT to self """ @@ -1253,22 +1248,19 @@ def _add_nat(self): raise TypeError( f"Cannot add {type(self).__name__} and {type(NaT).__name__}" ) - self = cast("TimedeltaArray | DatetimeArray", self) # GH#19124 pd.NaT is treated like a timedelta for both timedelta # and datetime dtypes result = np.empty(self.shape, dtype=np.int64) result.fill(iNaT) result = result.view(self._ndarray.dtype) # preserve reso - # error: Argument "dtype" to "_simple_new" of "DatetimeArray" has - # incompatible type "Union[dtype[timedelta64], dtype[datetime64], - # DatetimeTZDtype]"; expected "Union[dtype[datetime64], DatetimeTZDtype]" + # error: Unexpected keyword argument "freq" for "_simple_new" of "NDArrayBacked" return type(self)._simple_new( - result, dtype=self.dtype, freq=None # type: ignore[arg-type] + result, dtype=self.dtype, freq=None # type: ignore[call-arg] ) @final - def _sub_nat(self): + def _sub_nat(self) -> np.ndarray: """ Subtract pd.NaT from self """ @@ -1313,7 +1305,7 @@ def _sub_periodlike(self, other: Period | PeriodArray) -> npt.NDArray[np.object_ return new_data @final - def _addsub_object_array(self, other: npt.NDArray[np.object_], op): + def _addsub_object_array(self, other: npt.NDArray[np.object_], op) -> np.ndarray: """ Add or subtract array-like of DateOffset objects @@ -1364,7 +1356,7 @@ def __add__(self, other): # scalar others if other is NaT: - result = self._add_nat() + result: np.ndarray | DatetimeLikeArrayMixin = self._add_nat() elif isinstance(other, (Tick, timedelta, np.timedelta64)): result = self._add_timedeltalike_scalar(other) elif isinstance(other, BaseOffset): @@ -1424,7 +1416,7 @@ def __sub__(self, other): # scalar others if other is NaT: - result = self._sub_nat() + result: np.ndarray | 
DatetimeLikeArrayMixin = self._sub_nat() elif isinstance(other, (Tick, timedelta, np.timedelta64)): result = self._add_timedeltalike_scalar(-other) elif isinstance(other, BaseOffset): diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py index 6b7ddc4a72957..de395a9b8c0a3 100644 --- a/pandas/core/arrays/datetimes.py +++ b/pandas/core/arrays/datetimes.py @@ -7,6 +7,8 @@ ) from typing import ( TYPE_CHECKING, + Generator, + TypeVar, cast, overload, ) @@ -86,9 +88,15 @@ npt, ) - from pandas import DataFrame + from pandas import ( + DataFrame, + Timedelta, + ) from pandas.core.arrays import PeriodArray + _TimestampNoneT1 = TypeVar("_TimestampNoneT1", Timestamp, None) + _TimestampNoneT2 = TypeVar("_TimestampNoneT2", Timestamp, None) + _ITER_CHUNKSIZE = 10_000 @@ -326,7 +334,7 @@ def _simple_new( # type: ignore[override] return result @classmethod - def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False): + def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False) -> Self: return cls._from_sequence_not_strict(scalars, dtype=dtype, copy=copy) @classmethod @@ -2125,7 +2133,7 @@ def std( ddof: int = 1, keepdims: bool = False, skipna: bool = True, - ): + ) -> Timedelta: """ Return sample standard deviation over requested axis. @@ -2191,7 +2199,7 @@ def _sequence_to_dt64( yearfirst: bool = False, ambiguous: TimeAmbiguous = "raise", out_unit: str | None = None, -): +) -> tuple[np.ndarray, tzinfo | None]: """ Parameters ---------- @@ -2360,7 +2368,7 @@ def objects_to_datetime64( errors: DateTimeErrorChoices = "raise", allow_object: bool = False, out_unit: str = "ns", -): +) -> tuple[np.ndarray, tzinfo | None]: """ Convert data to array of timestamps. 
@@ -2665,8 +2673,8 @@ def _infer_tz_from_endpoints( def _maybe_normalize_endpoints( - start: Timestamp | None, end: Timestamp | None, normalize: bool -): + start: _TimestampNoneT1, end: _TimestampNoneT2, normalize: bool +) -> tuple[_TimestampNoneT1, _TimestampNoneT2]: if normalize: if start is not None: start = start.normalize() @@ -2717,7 +2725,7 @@ def _generate_range( offset: BaseOffset, *, unit: str, -): +) -> Generator[Timestamp, None, None]: """ Generates a sequence of dates corresponding to the specified time offset. Similar to dateutil.rrule except uses pandas DateOffset diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index a19b304529383..7d2d98f71b38c 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -8,6 +8,7 @@ import textwrap from typing import ( TYPE_CHECKING, + Callable, Literal, Union, overload, @@ -232,7 +233,7 @@ def __new__( dtype: Dtype | None = None, copy: bool = False, verify_integrity: bool = True, - ): + ) -> Self: data = extract_array(data, extract_numpy=True) if isinstance(data, cls): @@ -1241,7 +1242,7 @@ def value_counts(self, dropna: bool = True) -> Series: # --------------------------------------------------------------------- # Rendering Methods - def _formatter(self, boxed: bool = False): + def _formatter(self, boxed: bool = False) -> Callable[[object], str]: # returning 'str' here causes us to render as e.g. 
"(0, 1]" instead of # "Interval(0, 1, closed='right')" return str @@ -1842,9 +1843,13 @@ def _from_combined(self, combined: np.ndarray) -> IntervalArray: dtype = self._left.dtype if needs_i8_conversion(dtype): assert isinstance(self._left, (DatetimeArray, TimedeltaArray)) - new_left = type(self._left)._from_sequence(nc[:, 0], dtype=dtype) + new_left: DatetimeArray | TimedeltaArray | np.ndarray = type( + self._left + )._from_sequence(nc[:, 0], dtype=dtype) assert isinstance(self._right, (DatetimeArray, TimedeltaArray)) - new_right = type(self._right)._from_sequence(nc[:, 1], dtype=dtype) + new_right: DatetimeArray | TimedeltaArray | np.ndarray = type( + self._right + )._from_sequence(nc[:, 1], dtype=dtype) else: assert isinstance(dtype, np.dtype) new_left = nc[:, 0].view(dtype) diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index 03c09c5b2fd18..dc0b5b60fc4a5 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -98,6 +98,7 @@ NumpySorter, NumpyValueArrayLike, ) + from pandas._libs.missing import NAType from pandas.compat.numpy import function as nv @@ -152,7 +153,7 @@ def _from_sequence(cls, scalars, *, dtype=None, copy: bool = False) -> Self: @classmethod @doc(ExtensionArray._empty) - def _empty(cls, shape: Shape, dtype: ExtensionDtype): + def _empty(cls, shape: Shape, dtype: ExtensionDtype) -> Self: values = np.empty(shape, dtype=dtype.type) values.fill(cls._internal_fill_value) mask = np.ones(shape, dtype=bool) @@ -499,7 +500,7 @@ def to_numpy( return data @doc(ExtensionArray.tolist) - def tolist(self): + def tolist(self) -> list: if self.ndim > 1: return [x.tolist() for x in self] dtype = None if self._hasna else self._data.dtype @@ -1307,7 +1308,21 @@ def max(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs): def map(self, mapper, na_action=None): return map_array(self.to_numpy(), mapper, na_action=None) - def any(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs): + @overload + 
def any( + self, *, skipna: Literal[True] = ..., axis: AxisInt | None = ..., **kwargs + ) -> bool: + ... + + @overload + def any( + self, *, skipna: bool, axis: AxisInt | None = ..., **kwargs + ) -> bool | NAType: + ... + + def any( + self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs + ) -> bool | NAType: """ Return whether any element is truthy. @@ -1379,7 +1394,7 @@ def any(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs): # bool, int, float, complex, str, bytes, # _NestedSequence[Union[bool, int, float, complex, str, bytes]]]" np.putmask(values, self._mask, self._falsey_value) # type: ignore[arg-type] - result = values.any() + result = values.any().item() if skipna: return result else: @@ -1388,6 +1403,18 @@ def any(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs): else: return self.dtype.na_value + @overload + def all( + self, *, skipna: Literal[True] = ..., axis: AxisInt | None = ..., **kwargs + ) -> bool: + ... + + @overload + def all( + self, *, skipna: bool, axis: AxisInt | None = ..., **kwargs + ) -> bool | NAType: + ... + def all(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs): """ Return whether all elements are truthy. 
@@ -1460,7 +1487,7 @@ def all(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs): # bool, int, float, complex, str, bytes, # _NestedSequence[Union[bool, int, float, complex, str, bytes]]]" np.putmask(values, self._mask, self._truthy_value) # type: ignore[arg-type] - result = values.all(axis=axis) + result = values.all(axis=axis).item() if skipna: return result diff --git a/pandas/core/arrays/sparse/accessor.py b/pandas/core/arrays/sparse/accessor.py index 3dd7ebf564ca1..a1d81aeeecb0b 100644 --- a/pandas/core/arrays/sparse/accessor.py +++ b/pandas/core/arrays/sparse/accessor.py @@ -17,6 +17,11 @@ from pandas.core.arrays.sparse.array import SparseArray if TYPE_CHECKING: + from scipy.sparse import ( + coo_matrix, + spmatrix, + ) + from pandas import ( DataFrame, Series, @@ -115,7 +120,9 @@ def from_coo(cls, A, dense_index: bool = False) -> Series: return result - def to_coo(self, row_levels=(0,), column_levels=(1,), sort_labels: bool = False): + def to_coo( + self, row_levels=(0,), column_levels=(1,), sort_labels: bool = False + ) -> tuple[coo_matrix, list, list]: """ Create a scipy.sparse.coo_matrix from a Series with MultiIndex. @@ -326,7 +333,7 @@ def to_dense(self) -> DataFrame: data = {k: v.array.to_dense() for k, v in self._parent.items()} return DataFrame(data, index=self._parent.index, columns=self._parent.columns) - def to_coo(self): + def to_coo(self) -> spmatrix: """ Return the contents of the frame as a sparse SciPy COO matrix. 
diff --git a/pandas/core/arrays/sparse/array.py b/pandas/core/arrays/sparse/array.py index 7a3ea85dde2b4..db670e1ea4816 100644 --- a/pandas/core/arrays/sparse/array.py +++ b/pandas/core/arrays/sparse/array.py @@ -584,11 +584,13 @@ def __setitem__(self, key, value) -> None: raise TypeError(msg) @classmethod - def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False): + def _from_sequence( + cls, scalars, *, dtype: Dtype | None = None, copy: bool = False + ) -> Self: return cls(scalars, dtype=dtype) @classmethod - def _from_factorized(cls, values, original): + def _from_factorized(cls, values, original) -> Self: return cls(values, dtype=original.dtype) # ------------------------------------------------------------------------ diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index f451ebc352733..d4da5840689de 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -257,7 +257,7 @@ class BaseStringArray(ExtensionArray): """ @doc(ExtensionArray.tolist) - def tolist(self): + def tolist(self) -> list: if self.ndim > 1: return [x.tolist() for x in self] return list(self.to_numpy()) @@ -381,7 +381,9 @@ def _validate(self) -> None: lib.convert_nans_to_NA(self._ndarray) @classmethod - def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False): + def _from_sequence( + cls, scalars, *, dtype: Dtype | None = None, copy: bool = False + ) -> Self: if dtype and not (isinstance(dtype, str) and dtype == "string"): dtype = pandas_dtype(dtype) assert isinstance(dtype, StringDtype) and dtype.storage == "python" @@ -414,7 +416,7 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = Fal @classmethod def _from_sequence_of_strings( cls, strings, *, dtype: Dtype | None = None, copy: bool = False - ): + ) -> Self: return cls._from_sequence(strings, dtype=dtype, copy=copy) @classmethod @@ -436,7 +438,7 @@ def __arrow_array__(self, type=None): values[self.isna()] = 
None return pa.array(values, type=type, from_pandas=True) - def _values_for_factorize(self): + def _values_for_factorize(self) -> tuple[np.ndarray, None]: arr = self._ndarray.copy() mask = self.isna() arr[mask] = None diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index d5a76811a12e6..cb07fcf1a48fa 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -59,6 +59,7 @@ AxisInt, Dtype, Scalar, + Self, npt, ) @@ -172,7 +173,9 @@ def __len__(self) -> int: return len(self._pa_array) @classmethod - def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = False): + def _from_sequence( + cls, scalars, *, dtype: Dtype | None = None, copy: bool = False + ) -> Self: from pandas.core.arrays.masked import BaseMaskedArray _chk_pyarrow_available() @@ -201,7 +204,7 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = Fal @classmethod def _from_sequence_of_strings( cls, strings, dtype: Dtype | None = None, copy: bool = False - ): + ) -> Self: return cls._from_sequence(strings, dtype=dtype, copy=copy) @property @@ -439,7 +442,7 @@ def _str_fullmatch( def _str_slice( self, start: int | None = None, stop: int | None = None, step: int | None = None - ): + ) -> Self: if stop is None: return super()._str_slice(start, stop, step) if start is None: @@ -490,27 +493,27 @@ def _str_len(self): result = pc.utf8_length(self._pa_array) return self._convert_int_dtype(result) - def _str_lower(self): + def _str_lower(self) -> Self: return type(self)(pc.utf8_lower(self._pa_array)) - def _str_upper(self): + def _str_upper(self) -> Self: return type(self)(pc.utf8_upper(self._pa_array)) - def _str_strip(self, to_strip=None): + def _str_strip(self, to_strip=None) -> Self: if to_strip is None: result = pc.utf8_trim_whitespace(self._pa_array) else: result = pc.utf8_trim(self._pa_array, characters=to_strip) return type(self)(result) - def _str_lstrip(self, to_strip=None): + def 
_str_lstrip(self, to_strip=None) -> Self: if to_strip is None: result = pc.utf8_ltrim_whitespace(self._pa_array) else: result = pc.utf8_ltrim(self._pa_array, characters=to_strip) return type(self)(result) - def _str_rstrip(self, to_strip=None): + def _str_rstrip(self, to_strip=None) -> Self: if to_strip is None: result = pc.utf8_rtrim_whitespace(self._pa_array) else: diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index d7a177c2a19c0..9a1ec2330a326 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -1541,25 +1541,24 @@ def construct_1d_arraylike_from_scalar( if isinstance(dtype, ExtensionDtype): cls = dtype.construct_array_type() seq = [] if length == 0 else [value] - subarr = cls._from_sequence(seq, dtype=dtype).repeat(length) + return cls._from_sequence(seq, dtype=dtype).repeat(length) + + if length and dtype.kind in "iu" and isna(value): + # coerce if we have nan for an integer dtype + dtype = np.dtype("float64") + elif lib.is_np_dtype(dtype, "US"): + # we need to coerce to object dtype to avoid + # to allow numpy to take our string as a scalar value + dtype = np.dtype("object") + if not isna(value): + value = ensure_str(value) + elif dtype.kind in "mM": + value = _maybe_box_and_unbox_datetimelike(value, dtype) - else: - if length and dtype.kind in "iu" and isna(value): - # coerce if we have nan for an integer dtype - dtype = np.dtype("float64") - elif lib.is_np_dtype(dtype, "US"): - # we need to coerce to object dtype to avoid - # to allow numpy to take our string as a scalar value - dtype = np.dtype("object") - if not isna(value): - value = ensure_str(value) - elif dtype.kind in "mM": - value = _maybe_box_and_unbox_datetimelike(value, dtype) - - subarr = np.empty(length, dtype=dtype) - if length: - # GH 47391: numpy > 1.24 will raise filling np.nan into int dtypes - subarr.fill(value) + subarr = np.empty(length, dtype=dtype) + if length: + # GH 47391: numpy > 1.24 will raise filling np.nan into int dtypes + 
subarr.fill(value) return subarr diff --git a/pandas/core/internals/array_manager.py b/pandas/core/internals/array_manager.py index e253f82256a5f..ee62441ab8f55 100644 --- a/pandas/core/internals/array_manager.py +++ b/pandas/core/internals/array_manager.py @@ -661,7 +661,9 @@ def fast_xs(self, loc: int) -> SingleArrayManager: values = [arr[loc] for arr in self.arrays] if isinstance(dtype, ExtensionDtype): - result = dtype.construct_array_type()._from_sequence(values, dtype=dtype) + result: np.ndarray | ExtensionArray = ( + dtype.construct_array_type()._from_sequence(values, dtype=dtype) + ) # for datetime64/timedelta64, the np.ndarray constructor cannot handle pd.NaT elif is_datetime64_ns_dtype(dtype): result = DatetimeArray._from_sequence(values, dtype=dtype)._ndarray diff --git a/pandas/core/internals/managers.py b/pandas/core/internals/managers.py index 5f38720135efa..d08dee3663395 100644 --- a/pandas/core/internals/managers.py +++ b/pandas/core/internals/managers.py @@ -971,7 +971,9 @@ def fast_xs(self, loc: int) -> SingleBlockManager: if len(self.blocks) == 1: # TODO: this could be wrong if blk.mgr_locs is not slice(None)-like; # is this ruled out in the general case? 
- result = self.blocks[0].iget((slice(None), loc)) + result: np.ndarray | ExtensionArray = self.blocks[0].iget( + (slice(None), loc) + ) # in the case of a single block, the new block is a view bp = BlockPlacement(slice(0, len(result))) block = new_block( @@ -2368,9 +2370,9 @@ def make_na_array(dtype: DtypeObj, shape: Shape, fill_value) -> ArrayLike: else: # NB: we should never get here with dtype integer or bool; # if we did, the missing_arr.fill would cast to gibberish - missing_arr = np.empty(shape, dtype=dtype) - missing_arr.fill(fill_value) + missing_arr_np = np.empty(shape, dtype=dtype) + missing_arr_np.fill(fill_value) if dtype.kind in "mM": - missing_arr = ensure_wrapped_if_datetimelike(missing_arr) - return missing_arr + missing_arr_np = ensure_wrapped_if_datetimelike(missing_arr_np) + return missing_arr_np diff --git a/pandas/core/strings/object_array.py b/pandas/core/strings/object_array.py index 0029beccc40a8..29d17e7174ee9 100644 --- a/pandas/core/strings/object_array.py +++ b/pandas/core/strings/object_array.py @@ -205,10 +205,10 @@ def rep(x, r): np.asarray(repeats, dtype=object), rep, ) - if isinstance(self, BaseStringArray): - # Not going through map, so we have to do this here. - result = type(self)._from_sequence(result, dtype=self.dtype) - return result + if not isinstance(self, BaseStringArray): + return result + # Not going through map, so we have to do this here. 
+ return type(self)._from_sequence(result, dtype=self.dtype) def _str_match( self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None diff --git a/pandas/core/tools/datetimes.py b/pandas/core/tools/datetimes.py index 05262c235568d..097765f5705af 100644 --- a/pandas/core/tools/datetimes.py +++ b/pandas/core/tools/datetimes.py @@ -445,7 +445,7 @@ def _convert_listlike_datetimes( # We can take a shortcut since the datetime64 numpy array # is in UTC out_unit = np.datetime_data(result.dtype)[0] - dtype = cast(DatetimeTZDtype, tz_to_dtype(tz_parsed, out_unit)) + dtype = tz_to_dtype(tz_parsed, out_unit) dt64_values = result.view(f"M8[{dtype.unit}]") dta = DatetimeArray._simple_new(dt64_values, dtype=dtype) return DatetimeIndex._simple_new(dta, name=name) From 1e97bde66ec3725e6d88ac76ccb7cdf282d1b36c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Sun, 31 Dec 2023 16:21:13 -0500 Subject: [PATCH 4/8] pyupgrade --- pandas/core/arrays/datetimes.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py index de395a9b8c0a3..a4d01dd6667f6 100644 --- a/pandas/core/arrays/datetimes.py +++ b/pandas/core/arrays/datetimes.py @@ -7,7 +7,6 @@ ) from typing import ( TYPE_CHECKING, - Generator, TypeVar, cast, overload, ) @@ -75,7 +74,10 @@ ) if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import ( + Generator, + Iterator, + ) from pandas._typing import ( ArrayLike, From 4d287b0a48c66c94dc750c14d4c4b7c49f39a90a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Sun, 31 Dec 2023 17:12:36 -0500 Subject: [PATCH 5/8] runtime actually expects np.bool_ (calls .reshape(1) on it) --- pandas/core/arrays/masked.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index dc0b5b60fc4a5..c1bac9cfcb02f 100644 ---
a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -1311,18 +1311,18 @@ def map(self, mapper, na_action=None): @overload def any( self, *, skipna: Literal[True] = ..., axis: AxisInt | None = ..., **kwargs - ) -> bool: + ) -> np.bool_: ... @overload def any( self, *, skipna: bool, axis: AxisInt | None = ..., **kwargs - ) -> bool | NAType: + ) -> np.bool_ | NAType: ... def any( self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs - ) -> bool | NAType: + ) -> np.bool_ | NAType: """ Return whether any element is truthy. @@ -1394,7 +1394,7 @@ def any( # bool, int, float, complex, str, bytes, # _NestedSequence[Union[bool, int, float, complex, str, bytes]]]" np.putmask(values, self._mask, self._falsey_value) # type: ignore[arg-type] - result = values.any().item() + result = values.any() if skipna: return result else: @@ -1406,16 +1406,18 @@ def any( @overload def all( self, *, skipna: Literal[True] = ..., axis: AxisInt | None = ..., **kwargs - ) -> bool: + ) -> np.bool_: ... @overload def all( self, *, skipna: bool, axis: AxisInt | None = ..., **kwargs - ) -> bool | NAType: + ) -> np.bool_ | NAType: ... - def all(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs): + def all( + self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs + ) -> np.bool_ | NAType: """ Return whether all elements are truthy. 
@@ -1487,7 +1489,7 @@ def all(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs): # bool, int, float, complex, str, bytes, # _NestedSequence[Union[bool, int, float, complex, str, bytes]]]" np.putmask(values, self._mask, self._truthy_value) # type: ignore[arg-type] - result = values.all(axis=axis).item() + result = values.all(axis=axis) if skipna: return result From 2e93dd5dd8d49d3b37594633683f49409ad99b5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Tue, 2 Jan 2024 10:11:58 -0500 Subject: [PATCH 6/8] TypeVar --- pandas/_typing.py | 1 + pandas/core/algorithms.py | 24 ++++++------------------ 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/pandas/_typing.py b/pandas/_typing.py index aed7aa7f2cc8e..a80f9603493a7 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -109,6 +109,7 @@ # array-like ArrayLike = Union["ExtensionArray", np.ndarray] +ArrayLikeT = TypeVar("ArrayLikeT", "ExtensionArray", np.ndarray) AnyArrayLike = Union[ArrayLike, "Index", "Series"] TimeArrayLike = Union["DatetimeArray", "TimedeltaArray"] diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 6f1ac75cebc0b..76fdcefd03407 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -11,7 +11,6 @@ TYPE_CHECKING, Literal, cast, - overload, ) import warnings @@ -26,6 +25,7 @@ from pandas._typing import ( AnyArrayLike, ArrayLike, + ArrayLikeT, AxisInt, DtypeObj, TakeIndexer, @@ -182,23 +182,9 @@ def _ensure_data(values: ArrayLike) -> np.ndarray: return ensure_object(values) -@overload def _reconstruct_data( - values: ExtensionArray, dtype: DtypeObj, original: AnyArrayLike -) -> ExtensionArray: - ... - - -@overload -def _reconstruct_data( - values: np.ndarray, dtype: DtypeObj, original: AnyArrayLike -) -> np.ndarray: - ... 
- - -def _reconstruct_data( - values: ArrayLike, dtype: DtypeObj, original: AnyArrayLike -) -> ArrayLike: + values: ArrayLikeT, dtype: DtypeObj, original: AnyArrayLike +) -> ArrayLikeT: """ reverse of _ensure_data @@ -221,7 +207,9 @@ def _reconstruct_data( # that values.dtype == dtype cls = dtype.construct_array_type() - values = cls._from_sequence(values, dtype=dtype) + # error: Incompatible types in assignment (expression has type + # "ExtensionArray", variable has type "ndarray[Any, Any]") + values = cls._from_sequence(values, dtype=dtype) # type: ignore[assignment] else: values = values.astype(dtype, copy=False) From df1a348529f84c38535c3c0b14bd79692c17d145 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Tue, 2 Jan 2024 15:12:27 -0500 Subject: [PATCH 7/8] return bool | NAType --- pandas/core/arrays/arrow/array.py | 8 ++------ pandas/core/arrays/base.py | 21 +++++++++++++++++++-- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 464b4d6f5c36b..d7c4d695e6951 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -829,9 +829,7 @@ def isna(self) -> npt.NDArray[np.bool_]: return self._pa_array.is_null().to_numpy() - # error: Signature of "any" incompatible with supertype - # "ExtensionArraySupportsAnyAll" - @overload # type: ignore[override] + @overload def any(self, *, skipna: Literal[True] = ..., **kwargs) -> bool: ... @@ -897,9 +895,7 @@ def any(self, *, skipna: bool = True, **kwargs) -> bool | NAType: """ return self._reduce("any", skipna=skipna, **kwargs) - # error: Signature of "all" incompatible with supertype - # "ExtensionArraySupportsAnyAll" - @overload # type: ignore[override] + @overload def all(self, *, skipna: Literal[True] = ..., **kwargs) -> bool: ... 
diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index b1a33c8fbd3a1..ece610a4d65ad 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -80,6 +80,7 @@ Iterator, Sequence, ) + from pandas._libs.missing import NAType from pandas._typing import ( ArrayLike, @@ -2387,10 +2388,26 @@ def _groupby_op( class ExtensionArraySupportsAnyAll(ExtensionArray): - def any(self, *, skipna: bool = True) -> bool: + @overload + def any(self, *, skipna: Literal[True] = ...) -> bool: + ... + + @overload + def any(self, *, skipna: bool) -> bool | NAType: + ... + + def any(self, *, skipna: bool = True) -> bool | NAType: raise AbstractMethodError(self) - def all(self, *, skipna: bool = True) -> bool: + @overload + def all(self, *, skipna: Literal[True] = ...) -> bool: + ... + + @overload + def all(self, *, skipna: bool) -> bool | NAType: + ... + + def all(self, *, skipna: bool = True) -> bool | NAType: raise AbstractMethodError(self) From 9db839e55407ab1483b84ab80a73f3f28c6b473a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torsten=20W=C3=B6rtwein?= Date: Tue, 2 Jan 2024 15:19:18 -0500 Subject: [PATCH 8/8] isort --- pandas/core/arrays/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index ece610a4d65ad..e530b28cba88a 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -80,8 +80,8 @@ Iterator, Sequence, ) - from pandas._libs.missing import NAType + from pandas._libs.missing import NAType from pandas._typing import ( ArrayLike, AstypeArg,