Skip to content

Commit

Permalink
REFACTOR-#7260: Use extract_dtype internal function in more places (#…
Browse files Browse the repository at this point in the history
…7261)

Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev authored May 14, 2024
1 parent 1c0d9a6 commit deddc14
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 63 deletions.
90 changes: 40 additions & 50 deletions modin/core/dataframe/pandas/metadata/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@

"""Module contains class ``ModinDtypes``."""

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Optional, Union

import numpy as np
import pandas
from pandas._typing import IndexLabel
from pandas._typing import DtypeObj, IndexLabel
from pandas.core.dtypes.cast import find_common_type

if TYPE_CHECKING:
Expand All @@ -33,13 +34,13 @@ class DtypesDescriptor:
Parameters
----------
known_dtypes : dict[IndexLabel, np.dtype] or pandas.Series, optional
known_dtypes : dict[IndexLabel, DtypeObj] or pandas.Series, optional
Columns that we know dtypes for.
cols_with_unknown_dtypes : list[IndexLabel], optional
Column names that have unknown dtypes. If specified together with `remaining_dtype`, must describe all
columns with unknown dtypes, otherwise, the missing columns will be assigned to `remaining_dtype`.
If `cols_with_unknown_dtypes` is incomplete, you must specify `know_all_names=False`.
remaining_dtype : np.dtype, optional
remaining_dtype : DtypeObj, optional
Dtype for columns that are not present neither in `known_dtypes` nor in `cols_with_unknown_dtypes`.
This parameter is intended to describe columns that we known dtypes for, but don't know their
names yet. Note, that this parameter DOESN'T describe dtypes for columns from `cols_with_unknown_dtypes`.
Expand All @@ -60,10 +61,10 @@ class DtypesDescriptor:

def __init__(
self,
known_dtypes: Optional[Union[dict[IndexLabel, np.dtype], pandas.Series]] = None,
known_dtypes: Optional[Union[dict[IndexLabel, DtypeObj], pandas.Series]] = None,
cols_with_unknown_dtypes: Optional[list[IndexLabel]] = None,
remaining_dtype: Optional[np.dtype] = None,
parent_df: Optional["PandasDataframe"] = None,
remaining_dtype: Optional[DtypeObj] = None,
parent_df: Optional[PandasDataframe] = None,
columns_order: Optional[dict[int, IndexLabel]] = None,
know_all_names: bool = True,
_schema_is_known: Optional[bool] = None,
Expand All @@ -73,7 +74,7 @@ def __init__(
"It's not allowed to pass 'remaining_dtype' and 'know_all_names=False' at the same time."
)
# columns with known dtypes
self._known_dtypes: dict[IndexLabel, np.dtype] = (
self._known_dtypes: dict[IndexLabel, DtypeObj] = (
{} if known_dtypes is None else dict(known_dtypes)
)
if known_dtypes is not None and len(self._known_dtypes) != len(known_dtypes):
Expand Down Expand Up @@ -106,8 +107,8 @@ def __init__(

self._know_all_names: bool = know_all_names
# a common dtype for columns that are not present in 'known_dtypes' nor in 'cols_with_unknown_dtypes'
self._remaining_dtype: Optional[np.dtype] = remaining_dtype
self._parent_df: Optional["PandasDataframe"] = parent_df
self._remaining_dtype: Optional[DtypeObj] = remaining_dtype
self._parent_df: Optional[PandasDataframe] = parent_df
if columns_order is None:
self._columns_order: Optional[dict[int, IndexLabel]] = None
# try to compute '._columns_order' using 'parent_df'
Expand All @@ -132,7 +133,7 @@ def __init__(
)
self._columns_order: Optional[dict[int, IndexLabel]] = columns_order

def update_parent(self, new_parent: "PandasDataframe"):
def update_parent(self, new_parent: PandasDataframe):
"""
Set new parent dataframe.
Expand Down Expand Up @@ -202,7 +203,7 @@ def __str__(self): # noqa: GL08

def lazy_get(
self, ids: list[Union[IndexLabel, int]], numeric_index: bool = False
) -> "DtypesDescriptor":
) -> DtypesDescriptor:
"""
Get dtypes descriptor for a subset of columns without triggering any computations.
Expand Down Expand Up @@ -255,7 +256,7 @@ def lazy_get(
columns_order=columns_order,
)

def copy(self) -> "DtypesDescriptor":
def copy(self) -> DtypesDescriptor:
"""
Get a copy of this descriptor.
Expand All @@ -279,7 +280,7 @@ def copy(self) -> "DtypesDescriptor":

def set_index(
self, new_index: Union[pandas.Index, "ModinIndex"]
) -> "DtypesDescriptor":
) -> DtypesDescriptor:
"""
Set new column names for this descriptor.
Expand Down Expand Up @@ -324,7 +325,7 @@ def set_index(
}
return new_self

def equals(self, other: "DtypesDescriptor") -> bool:
def equals(self, other: DtypesDescriptor) -> bool:
"""
Compare two descriptors for equality.
Expand Down Expand Up @@ -441,25 +442,25 @@ def to_series(self) -> pandas.Series:
self.materialize()
return pandas.Series(self._known_dtypes)

def get_dtypes_set(self) -> set[np.dtype]:
def get_dtypes_set(self) -> set[DtypeObj]:
"""
Get a set of dtypes from the descriptor.
Returns
-------
set[np.dtype]
set[DtypeObj]
"""
if len(self._cols_with_unknown_dtypes) > 0 or not self._know_all_names:
self._materialize_cols_with_unknown_dtypes()
known_dtypes: set[np.dtype] = set(self._known_dtypes.values())
known_dtypes: set[DtypeObj] = set(self._known_dtypes.values())
if self._remaining_dtype is not None:
known_dtypes.add(self._remaining_dtype)
return known_dtypes

@classmethod
def _merge_dtypes(
cls, values: list[Union["DtypesDescriptor", pandas.Series, None]]
) -> "DtypesDescriptor":
cls, values: list[Union[DtypesDescriptor, pandas.Series, None]]
) -> DtypesDescriptor:
"""
Union columns described by ``values`` and compute common dtypes for them.
Expand Down Expand Up @@ -551,8 +552,8 @@ def combine_dtypes(row):

@classmethod
def concat(
cls, values: list[Union["DtypesDescriptor", pandas.Series, None]], axis: int = 0
) -> "DtypesDescriptor":
cls, values: list[Union[DtypesDescriptor, pandas.Series, None]], axis: int = 0
) -> DtypesDescriptor:
"""
Concatenate dtypes descriptors into a single descriptor.
Expand Down Expand Up @@ -746,9 +747,7 @@ class ModinDtypes:

def __init__(
self,
value: Optional[
Union[Callable, pandas.Series, DtypesDescriptor, "ModinDtypes"]
],
value: Optional[Union[Callable, pandas.Series, DtypesDescriptor, ModinDtypes]],
):
if callable(value) or isinstance(value, pandas.Series):
self._value = value
Expand Down Expand Up @@ -778,23 +777,21 @@ def is_materialized(self) -> bool:
"""
return isinstance(self._value, pandas.Series)

def get_dtypes_set(self) -> set[np.dtype]:
def get_dtypes_set(self) -> set[DtypeObj]:
"""
Get a set of dtypes from the descriptor.
Returns
-------
set[np.dtype]
set[DtypeObj]
"""
if isinstance(self._value, DtypesDescriptor):
return self._value.get_dtypes_set()
if not self.is_materialized:
self.get()
return set(self._value.values)

def maybe_specify_new_frame_ref(
self, new_parent: "PandasDataframe"
) -> "ModinDtypes":
def maybe_specify_new_frame_ref(self, new_parent: PandasDataframe) -> ModinDtypes:
"""
Set a new parent for the stored value if needed.
Expand All @@ -816,7 +813,7 @@ def maybe_specify_new_frame_ref(
return new_self
return new_self

def lazy_get(self, ids: list, numeric_index: bool = False) -> "ModinDtypes":
def lazy_get(self, ids: list, numeric_index: bool = False) -> ModinDtypes:
"""
Get new ``ModinDtypes`` for a subset of columns without triggering any computations.
Expand Down Expand Up @@ -848,7 +845,7 @@ def lazy_get(self, ids: list, numeric_index: bool = False) -> "ModinDtypes":
return ModinDtypes(self._value.iloc[ids] if numeric_index else self._value[ids])

@classmethod
def concat(cls, values: list, axis: int = 0) -> "ModinDtypes":
def concat(cls, values: list, axis: int = 0) -> ModinDtypes:
"""
Concatenate dtypes.
Expand Down Expand Up @@ -892,7 +889,7 @@ def concat(cls, values: list, axis: int = 0) -> "ModinDtypes":
desc = pandas.concat(values)
return ModinDtypes(desc)

def set_index(self, new_index: Union[pandas.Index, "ModinIndex"]) -> "ModinDtypes":
def set_index(self, new_index: Union[pandas.Index, "ModinIndex"]) -> ModinDtypes:
"""
Set new column names for stored dtypes.
Expand Down Expand Up @@ -996,7 +993,7 @@ def __getattr__(self, name):
self.get()
return self._value.__getattribute__(name)

def copy(self) -> "ModinDtypes":
def copy(self) -> ModinDtypes:
"""
Copy an object without materializing the internal representation.
Expand Down Expand Up @@ -1200,7 +1197,7 @@ def _materialize_categories(self):

def get_categories_dtype(
cdt: Union[LazyProxyCategoricalDtype, pandas.CategoricalDtype]
):
) -> DtypeObj:
"""
Get the categories dtype.
Expand All @@ -1219,7 +1216,7 @@ def get_categories_dtype(
)


def extract_dtype(value):
def extract_dtype(value) -> DtypeObj | pandas.Series:
"""
Extract dtype(s) from the passed `value`.
Expand All @@ -1229,18 +1226,11 @@ def extract_dtype(value):
Returns
-------
numpy.dtype or pandas.Series of numpy.dtypes
DtypeObj or pandas.Series of DtypeObj
"""
from modin.pandas.utils import is_scalar

if hasattr(value, "dtype"):
return value.dtype
elif hasattr(value, "dtypes"):
return value.dtypes
elif is_scalar(value):
if value is None:
# previous type was object instead of 'float64'
return pandas.api.types.pandas_dtype(value)
return pandas.api.types.pandas_dtype(type(value))
else:
return np.array(value).dtype
try:
dtype = pandas.api.types.pandas_dtype(value)
except TypeError:
dtype = pandas.Series(value).dtype

return dtype
15 changes: 2 additions & 13 deletions modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2831,18 +2831,7 @@ def _set_item(df, row_loc): # pragma: no cover
if self._modin_frame.has_materialized_dtypes and is_scalar(item):
new_dtypes = self.dtypes.copy()
old_dtypes = new_dtypes[col_loc]

if hasattr(item, "dtype"):
# If we're dealing with a numpy scalar (np.int, np.datetime64, ...)
# we would like to get its internal dtype
item_type = item.dtype
elif hasattr(item, "to_numpy"):
# If we're dealing with a scalar that can be converted to numpy (for example pandas.Timestamp)
# we would like to convert it and get its proper internal dtype
item_type = item.to_numpy().dtype
else:
item_type = pandas.api.types.pandas_dtype(type(item))

item_type = extract_dtype(item)
if isinstance(old_dtypes, pandas.Series):
new_dtypes[col_loc] = [
find_common_type([dtype, item_type]) for dtype in old_dtypes.values
Expand Down Expand Up @@ -4536,7 +4525,7 @@ def iloc_mut(partition, row_internal_indices, col_internal_indices, item):
)
else:
broadcasted_item, broadcasted_dtypes = item, pandas.Series(
[np.array(item).dtype] * len(col_numeric_index)
[extract_dtype(item)] * len(col_numeric_index)
)

new_dtypes = None
Expand Down

0 comments on commit deddc14

Please sign in to comment.