fix(typing): Resolve all mypy & pyright errors for _arrow #2007

Merged
merged 63 commits into main from typing-major-fixing-1 on Feb 17, 2025
Changes from 15 commits
Commits
e7f465f
fix(typing): Resolve all `pyright` warnings for `_arrow`
dangotbanned Feb 13, 2025
0015ac1
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 13, 2025
1d24dec
chore(typing): add extra ignore
dangotbanned Feb 13, 2025
286beeb
ci: omit `_arrow.typing.py` from coverage
dangotbanned Feb 13, 2025
757c6bd
fix(typing): error: Name "ser" already defined on line 59 [no-redef]
dangotbanned Feb 13, 2025
b830aa2
fix(DRAFT): try reordering `maybe_extract_py_scalar` overloads
dangotbanned Feb 13, 2025
5cadbed
ci(typing): re-enable `pyarrow-stubs`
dangotbanned Feb 13, 2025
1e6eb75
fix: only check for `.ordered` when type can have property
dangotbanned Feb 13, 2025
51ad255
fix(typing): use `utils.chunked_array`
dangotbanned Feb 13, 2025
658c963
fix(typing): misc assignment/redef errors
dangotbanned Feb 13, 2025
84033f1
fix(typing): error: "Table" has no attribute "__iter__" (not iterable…
dangotbanned Feb 13, 2025
f35b01e
chore(typing): lie about `broadcast_series`
dangotbanned Feb 13, 2025
1b537e7
fix(typing): error: Value of type variable "_ArrayT" of "concat_array…
dangotbanned Feb 13, 2025
89672e2
chore(DRAFT): add comment on `pyarrow.interchange.from_dataframe`
dangotbanned Feb 14, 2025
ce986e1
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 14, 2025
4d4a722
fix(typing): ambiguous `pyarrow.interchange.from_dataframe` import
dangotbanned Feb 14, 2025
e2940e7
chore(typing): help `mypy` match `@overload`
dangotbanned Feb 14, 2025
c7f23d5
chore(typing): mark `pc.max_element_wise` as `Incomplete` everywhere
dangotbanned Feb 14, 2025
b7d149d
chore(typing): more help `mypy` match `@overload`
dangotbanned Feb 14, 2025
1445db8
test(typing): resolve `[assignment]`
dangotbanned Feb 14, 2025
8c98172
test(typing): more `[assignment]`
dangotbanned Feb 14, 2025
a246ea7
test(typing): ignore `[arg-type]` when raising
dangotbanned Feb 14, 2025
cac7dbb
test(typing): even more `[assignment]`
dangotbanned Feb 14, 2025
35f801b
test(typing): even more helping `@overload` match
dangotbanned Feb 14, 2025
4a9ff86
test(typing): more ignore `[arg-type]` when raising
dangotbanned Feb 14, 2025
f86bf10
ci: ignore coverage within a test?
dangotbanned Feb 14, 2025
cf27f7d
test(typing): fix indirect `[assignment]` for `mypy`
dangotbanned Feb 14, 2025
3bdf0e8
refactor: replace `pa.scalar` w/ `lit` alias
dangotbanned Feb 14, 2025
27da267
fix(typing): use `nulls_like` in `Series.shift`
dangotbanned Feb 14, 2025
c43d092
fix(typing): widen `ArrowSeries.(mean|median)`
dangotbanned Feb 14, 2025
b549ee5
ci(typing): remove `pyarrow` comments
dangotbanned Feb 14, 2025
483db21
refactor: reuse `nulls_like`
dangotbanned Feb 14, 2025
44c8c52
feat(typing): correct methods returning `ArrowSeries[pa.BooleanScalar]`
dangotbanned Feb 14, 2025
bcec6f1
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 14, 2025
1b065f4
Merge branch 'main' into typing-major-fixing-1
dangotbanned Feb 14, 2025
41c9661
feat(typing): `series_str` return annotations
dangotbanned Feb 14, 2025
85b01a6
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 15, 2025
adef6a0
refactor(typing): use `Incomplete` for `nulls_like`
dangotbanned Feb 15, 2025
4a2cfc5
refactor(typing): mark `pc.binary_join_element_wise` as `Incomplete`
dangotbanned Feb 15, 2025
f704424
refactor(DRAFT): try simplifying `ArrowSeries.to_pandas`
dangotbanned Feb 15, 2025
f8617ff
revert: undo `ArrowSeries.to_pandas` change
dangotbanned Feb 15, 2025
b27665f
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 15, 2025
ba2b3a5
refactor(typing): narrow to `aggs: list[tuple[str, str, Any]]`
dangotbanned Feb 16, 2025
86e0c2a
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 16, 2025
8d13198
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 17, 2025
e5887eb
remove unneeded no cover
MarcoGorelli Feb 17, 2025
a947cf6
arrow scalar to _arrow/typing
MarcoGorelli Feb 17, 2025
5dcd2f3
rename to ArrowScalarT_co
MarcoGorelli Feb 17, 2025
44560d1
refactor(typing): Define `__lib_pxi.types._BasicDataType` internally
dangotbanned Feb 17, 2025
481a3e4
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 17, 2025
9c267c8
fix: Make `ArrowScalarT_co` available at runtime
dangotbanned Feb 17, 2025
59af046
simplify
MarcoGorelli Feb 17, 2025
a18e1bc
Merge branch 'typing-major-fixing-1' of github.com:narwhals-dev/narwh…
MarcoGorelli Feb 17, 2025
bc02cdb
refactor: avoid `pyarrow.__lib_pxi.types` in `utils`
dangotbanned Feb 17, 2025
8c8dbe7
Merge branch 'typing-major-fixing-1' of https://github.com/narwhals-d…
dangotbanned Feb 17, 2025
8ca893c
fixup
MarcoGorelli Feb 17, 2025
e6cecbc
Merge branch 'typing-major-fixing-1' of https://github.com/narwhals-d…
dangotbanned Feb 17, 2025
aa62a32
remove unalive code
MarcoGorelli Feb 17, 2025
9507261
Merge branch 'typing-major-fixing-1' of github.com:narwhals-dev/narwh…
MarcoGorelli Feb 17, 2025
3ea9674
last one
MarcoGorelli Feb 17, 2025
9bba7aa
fixup tests
MarcoGorelli Feb 17, 2025
1ceda96
fixup typing
MarcoGorelli Feb 17, 2025
837f0e6
pyright ignore
dangotbanned Feb 17, 2025
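Several commits above (5dcd2f3, 44560d1, 9c267c8, 44c8c52) revolve around one idea: parameterize `ArrowSeries` by the pyarrow scalar type it yields, via a covariant type variable. A minimal sketch of that pattern — the names and method bodies here are illustrative assumptions, not narwhals' actual definitions, which live in `narwhals/_arrow/typing.py` and the series module:

```python
from __future__ import annotations

from typing import Generic, TypeVar

import pyarrow as pa

# Covariant, so ArrowSeries[pa.BooleanScalar] is accepted wherever
# ArrowSeries[pa.Scalar] is expected.
ArrowScalarT_co = TypeVar("ArrowScalarT_co", bound=pa.Scalar, covariant=True)


class ArrowSeries(Generic[ArrowScalarT_co]):
    def __init__(self, native: pa.ChunkedArray) -> None:
        self._native = native

    def __getitem__(self, idx: int) -> ArrowScalarT_co:
        # Indexing now reveals a concrete scalar type to the checker.
        return self._native[idx]  # type: ignore[return-value]

    def is_null(self) -> ArrowSeries[pa.BooleanScalar]:
        # Predicate methods can advertise BooleanScalar precisely, which is
        # what "correct methods returning ArrowSeries[pa.BooleanScalar]"
        # (44c8c52) is about.
        out: ArrowSeries[pa.BooleanScalar] = ArrowSeries(self._native.is_null())
        return out


s: ArrowSeries[pa.Int64Scalar] = ArrowSeries(pa.chunked_array([[1, None, 3]]))
flags = s.is_null()  # checker sees ArrowSeries[pa.BooleanScalar]
```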
53 changes: 25 additions & 28 deletions narwhals/_arrow/dataframe.py
@@ -1,5 +1,6 @@
from __future__ import annotations

+from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterator
@@ -36,9 +37,6 @@

import pandas as pd
import polars as pl
-    from pyarrow._stubs_typing import (  # pyright: ignore[reportMissingModuleSource]
-        Indices,
-    )
from pyarrow._stubs_typing import Order # pyright: ignore[reportMissingModuleSource]
from typing_extensions import Self
from typing_extensions import TypeAlias
@@ -47,6 +45,8 @@
from narwhals._arrow.group_by import ArrowGroupBy
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.series import ArrowSeries
+    from narwhals._arrow.typing import Indices
+    from narwhals._arrow.typing import Mask
from narwhals.dtypes import DType
from narwhals.typing import SizeUnit
from narwhals.typing import _1DArray
@@ -133,7 +133,7 @@ def __len__(self: Self) -> int:
return len(self._native_frame)

def row(self: Self, index: int) -> tuple[Any, ...]:
-        return tuple(col[index] for col in self._native_frame)
+        return tuple(col[index] for col in self._native_frame.itercolumns())

@overload
def rows(self: Self, *, named: Literal[True]) -> list[dict[str, Any]]: ...
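The `row` change above pairs with commit 84033f1: iterating a `pa.Table` directly works at runtime, but `pyarrow-stubs` gives `Table` no `__iter__`, so the code goes through the explicitly typed `Table.itercolumns()` instead. In miniature:

```python
import pyarrow as pa

table = pa.table({"a": [1, 2], "b": ["x", "y"]})

# row = tuple(col[0] for col in table)  # runtime-fine, rejected by the stubs
row = tuple(col[0] for col in table.itercolumns())  # (1, "x") as pyarrow scalars
```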
@@ -165,7 +165,7 @@ def iter_rows(
for i in range(0, num_rows, buffer_size):
yield from df[i : i + buffer_size].to_pylist()

-    def get_column(self: Self, name: str) -> ArrowSeries:
+    def get_column(self: Self, name: str) -> ArrowSeries[Any]:
from narwhals._arrow.series import ArrowSeries

if not isinstance(name, str):
@@ -185,7 +185,7 @@ def __array__(self: Self, dtype: Any, copy: bool | None) -> _2DArray:
@overload
def __getitem__( # type: ignore[overload-overlap, unused-ignore]
self: Self, item: str | tuple[slice | Sequence[int] | _1DArray, int | str]
-    ) -> ArrowSeries: ...
+    ) -> ArrowSeries[Any]: ...
@overload
def __getitem__(
self: Self,
@@ -214,7 +214,7 @@ def __getitem__(
slice | Sequence[int] | _1DArray, slice | Sequence[int] | Sequence[str]
]
),
-    ) -> ArrowSeries | Self:
+    ) -> ArrowSeries[Any] | Self:
if isinstance(item, tuple):
item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item) # pyright: ignore[reportAssignmentType]

@@ -345,7 +345,7 @@ def aggregate(self: Self, *exprs: ArrowExpr) -> Self:
return self.select(*exprs)

def select(self: Self, *exprs: ArrowExpr) -> Self:
-        new_series: list[ArrowSeries] = evaluate_into_exprs(self, *exprs)
+        new_series: list[ArrowSeries[Any]] = evaluate_into_exprs(self, *exprs)
if not new_series:
# return empty dataframe, like Polars does
return self._from_native_frame(
@@ -357,7 +357,7 @@ def select(self: Self, *exprs: ArrowExpr) -> Self:

def with_columns(self: Self, *exprs: ArrowExpr) -> Self:
native_frame = self._native_frame
-        new_columns: list[ArrowSeries] = evaluate_into_exprs(self, *exprs)
+        new_columns: list[ArrowSeries[Any]] = evaluate_into_exprs(self, *exprs)

length = len(self)
columns = self.columns
@@ -497,14 +497,16 @@ def to_numpy(self: Self) -> _2DArray:
return arr

@overload
-    def to_dict(self: Self, *, as_series: Literal[True]) -> dict[str, ArrowSeries]: ...
+    def to_dict(
+        self: Self, *, as_series: Literal[True]
+    ) -> dict[str, ArrowSeries[Any]]: ...

@overload
def to_dict(self: Self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...

def to_dict(
self: Self, *, as_series: bool
-    ) -> dict[str, ArrowSeries] | dict[str, list[Any]]:
+    ) -> dict[str, ArrowSeries[Any]] | dict[str, list[Any]]:
df = self._native_frame

names_and_values = zip(df.column_names, df.columns)
@@ -532,9 +534,9 @@ def with_row_index(self: Self, name: str) -> Self:
df.append_column(name, row_indices).select([name, *cols])
)

-    def filter(self: Self, predicate: ArrowExpr | list[bool]) -> Self:
+    def filter(self: Self, predicate: ArrowExpr | list[bool | None]) -> Self:
        if isinstance(predicate, list):
-            mask_native = predicate
+            mask_native: Mask = predicate
else:
# `[0]` is safe as the predicate's expression only returns a single column
mask = evaluate_into_exprs(self, predicate)[0]
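Annotating `mask_native` with the `Mask` alias imported above gives the two branches one declared type, instead of letting the first assignment pin the variable to `list[bool | None]` and triggering `[assignment]` on the second. A sketch with an assumed alias shape (the real `Mask` in `narwhals/_arrow/typing.py` may differ):

```python
from __future__ import annotations

from typing import Sequence, Union

import pyarrow as pa
from typing_extensions import TypeAlias

# Assumed shape: whatever pa.Table.filter will accept as a row mask.
Mask: TypeAlias = Union[Sequence[Union[bool, None]], pa.ChunkedArray]


def normalize(predicate: list[bool | None] | pa.ChunkedArray) -> Mask:
    if isinstance(predicate, list):
        mask_native: Mask = predicate  # without the annotation, mypy pins list[...]
    else:
        mask_native = predicate
    return mask_native
```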
@@ -703,7 +705,7 @@ def write_csv(self: Self, file: str | Path | BytesIO | None) -> str | None:
pa_csv.write_csv(pa_table, file)
return None

-    def is_unique(self: Self) -> ArrowSeries:
+    def is_unique(self: Self) -> ArrowSeries[Any]:
from narwhals._arrow.series import ArrowSeries

col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
@@ -745,17 +747,14 @@ def unique(

agg_func = agg_func_map[keep]
col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
-        keep_idx = (
+        keep_idx_native = (
            df.append_column(col_token, pa.array(np.arange(len(self))))
            .group_by(subset)
            .aggregate([(col_token, agg_func)])
            .column(f"{col_token}_{agg_func}")
        )
-
-        return self._from_native_frame(
-            pc.take(df, keep_idx),  # type: ignore[call-overload, unused-ignore]
-            validate_column_names=False,
-        )
+        indices = cast("Indices", keep_idx_native)
+        return self._from_native_frame(df.take(indices), validate_column_names=False)

keep_idx = self.simple_select(*subset).is_unique()
plx = self.__narwhals_namespace__()
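Two moves in the `unique` hunk above are worth noting: the computed index column goes through `Table.take` rather than `pc.take`, and the blanket `# type: ignore[call-overload]` becomes a single `cast` to the stubs' `Indices` alias (re-exported via `narwhals._arrow.typing`). A cast asserts one specific claim and leaves every other diagnostic on the line live; an ignore mutes them all. Roughly, with `Indices` stubbed in as an assumption:

```python
from typing import Any, cast

import pyarrow as pa
from typing_extensions import TypeAlias

# Stand-in for the `Indices` alias from pyarrow-stubs; only a sketch.
Indices: TypeAlias = "pa.ChunkedArray[Any]"

table = pa.table({"a": ["x", "y", "z"]})
keep_idx = pa.chunked_array([[2, 0]])

# The cast documents exactly what is being claimed about keep_idx.
taken = table.take(cast(Indices, keep_idx))
print(taken["a"].to_pylist())  # ['z', 'x']
```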
@@ -804,30 +803,28 @@ def unpivot(
on_: list[str] = (
[c for c in self.columns if c not in index_] if on is None else on
)

-        promote_kwargs: dict[Literal["promote_options"], PromoteOptions] = (
-            {"promote_options": "permissive"}
+        concat = (
+            partial(pa.concat_tables, promote_options="permissive")
            if self._backend_version >= (14, 0, 0)
-            else {}
+            else pa.concat_tables
)
names = [*index_, variable_name, value_name]
return self._from_native_frame(
-            pa.concat_tables(
+            concat(
[
pa.Table.from_arrays(
[
*(native_frame.column(idx_col) for idx_col in index_),
cast(
"pa.ChunkedArray",
"pa.ChunkedArray[Any]",
pa.array([on_col] * n_rows, pa.string()),
),
native_frame.column(on_col),
],
names=names,
)
for on_col in on_
-                ],
-                **promote_kwargs,
+                ]
)
)
# TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
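The `unpivot` rewrite swaps a kwargs dict — whose type had to be spelled `dict[Literal["promote_options"], PromoteOptions]` — for a `functools.partial`, which checkers understand without any manual annotation. The version gate keeps `promote_options` away from pyarrow < 14.0, where `pa.concat_tables` does not accept it. The trick in isolation:

```python
from functools import partial

import pyarrow as pa

pa_version = tuple(int(part) for part in pa.__version__.split(".")[:2])

# Choose the callable once; no **kwargs dict, nothing to annotate by hand.
concat = (
    partial(pa.concat_tables, promote_options="permissive")
    if pa_version >= (14, 0)
    else pa.concat_tables
)

combined = concat([pa.table({"a": [1]}), pa.table({"a": [2]})])
```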
16 changes: 8 additions & 8 deletions narwhals/_arrow/expr.py
@@ -29,12 +29,12 @@
from narwhals.utils import Version


-class ArrowExpr(CompliantExpr[ArrowSeries]):
+class ArrowExpr(CompliantExpr[ArrowSeries[Any]]):
_implementation: Implementation = Implementation.PYARROW

def __init__(
self: Self,
-        call: Callable[[ArrowDataFrame], Sequence[ArrowSeries]],
+        call: Callable[[ArrowDataFrame], Sequence[ArrowSeries[Any]]],
*,
depth: int,
function_name: str,
@@ -48,7 +48,7 @@ def __init__(
self._depth = depth
self._function_name = function_name
self._depth = depth
-        self._evaluate_output_names = evaluate_output_names
+        self._evaluate_output_names = evaluate_output_names  # pyright: ignore[reportAttributeAccessIssue]
self._alias_output_names = alias_output_names
self._backend_version = backend_version
self._version = version
@@ -57,7 +57,7 @@
def __repr__(self: Self) -> str: # pragma: no cover
return f"ArrowExpr(depth={self._depth}, function_name={self._function_name}, "

-    def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
+    def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries[Any]]:
return self._call(df)

@classmethod
@@ -69,7 +69,7 @@ def from_column_names(
) -> Self:
from narwhals._arrow.series import ArrowSeries

-        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
+        def func(df: ArrowDataFrame) -> list[ArrowSeries[Any]]:
try:
return [
ArrowSeries(
@@ -106,7 +106,7 @@ def from_column_indices(
) -> Self:
from narwhals._arrow.series import ArrowSeries

-        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
+        def func(df: ArrowDataFrame) -> list[ArrowSeries[Any]]:
return [
ArrowSeries(
df._native_frame[column_index],
Expand Down Expand Up @@ -370,7 +370,7 @@ def clip(self: Self, lower_bound: Any | None, upper_bound: Any | None) -> Self:
)

def over(self: Self, keys: list[str]) -> Self:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
-        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
+        def func(df: ArrowDataFrame) -> list[ArrowSeries[Any]]:
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
if overlap := set(output_names).intersection(keys):
# E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
@@ -406,7 +406,7 @@ def map_batches(
function: Callable[[Any], Any],
return_dtype: DType | None,
) -> Self:
-        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
+        def func(df: ArrowDataFrame) -> list[ArrowSeries[Any]]:
input_series_list = self._call(df)
output_names = [input_series.name for input_series in input_series_list]
result = [function(series) for series in input_series_list]
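Most of this file's churn is mechanical: once `ArrowSeries` is generic, strict pyright flags every bare mention (`reportMissingTypeArgument`), so annotations say `ArrowSeries[Any]` explicitly until a more precise parameter is worth threading through. The difference in miniature, with a hypothetical class:

```python
from typing import Any, Generic, TypeVar

T_co = TypeVar("T_co", covariant=True)


class Series(Generic[T_co]): ...


def implicit() -> Series:  # strict pyright: reportMissingTypeArgument
    return Series()


def explicit() -> Series[Any]:  # same meaning, but stated on purpose
    return Series()
```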
28 changes: 20 additions & 8 deletions narwhals/_arrow/group_by.py
@@ -5,6 +5,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterator
+from typing import cast

import pyarrow as pa
import pyarrow.compute as pc
@@ -68,7 +69,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
)
raise ValueError(msg)

-        aggs: list[tuple[str, str, pc.FunctionOptions | None]] = []
+        aggs: list[Any] = []
expected_pyarrow_column_names: list[str] = self._keys.copy()
new_column_names: list[str] = self._keys.copy()

@@ -91,7 +92,7 @@

function_name = re.sub(r"(\w+->)", "", expr._function_name)
if function_name in {"std", "var"}:
-                option = pc.VarianceOptions(ddof=expr._kwargs["ddof"])
+                option: Any = pc.VarianceOptions(ddof=expr._kwargs["ddof"])
elif function_name in {"len", "n_unique"}:
option = pc.CountOptions(mode="all")
elif function_name == "count":
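The `aggs` list above is temporarily `list[Any]`; commit ba2b3a5 later narrows it to `list[tuple[str, str, Any]]` — the `(column, function, options)` triples that `pa.TableGroupBy.aggregate` consumes, with `FunctionOptions` values like the `VarianceOptions`/`CountOptions` built in this hunk. The runtime shape being described:

```python
import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({"k": ["a", "a", "b"], "v": [1, None, 3]})

aggs = [
    ("v", "sum", None),                           # options may be None
    ("v", "count", pc.CountOptions(mode="all")),  # or a FunctionOptions
]
out = table.group_by("k").aggregate(aggs)
print(out.column_names)  # e.g. ['v_sum', 'v_count', 'k']
```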
@@ -139,15 +140,26 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:

def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:
col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns)
-        null_token = "__null_token_value__"  # noqa: S105
+        null_token: str = "__null_token_value__"  # noqa: S105

table = self._df._native_frame
-        key_values = pc.binary_join_element_wise(
-            *[pc.cast(table[key], pa.string()) for key in self._keys],
-            "",
-            null_handling="replace",
-            null_replacement=null_token,
-        )
+        # NOTE: stubs fail in multiple places for `ChunkedArray`
+        it = cast(
+            "Iterator[pa.StringArray]",
+            (table[key].cast(pa.string()) for key in self._keys),
+        )
+        if TYPE_CHECKING:
+            # NOTE: stubs indicate `separator` would get appended to the end, instead of between elements
+            key_values = pc.binary_join_element_wise(
+                *it, null_handling="replace", null_replacement=null_token
+            )
+        else:
+            key_values = pc.binary_join_element_wise(
+                *it,
+                "",
+                null_handling="replace",
+                null_replacement=null_token,
+            )
table = table.add_column(i=0, field_=col_token, column=key_values)

yield from (
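The `if TYPE_CHECKING:` split above is the bluntest instrument in the PR: when the stubs mis-describe a function — here, per the NOTE, `pyarrow-stubs` models the final positional argument of `pc.binary_join_element_wise` as appended to the end rather than joined between elements — the checker is shown one call shape while the interpreter executes another. A self-contained sketch of the same maneuver (separator semantics per pyarrow's docs, not the stubs):

```python
from typing import TYPE_CHECKING

import pyarrow as pa
import pyarrow.compute as pc

a = pa.chunked_array([["x", None]])
b = pa.chunked_array([["y", "z"]])

if TYPE_CHECKING:
    # The call shape the stubs accept: no positional separator.
    joined = pc.binary_join_element_wise(
        a, b, null_handling="replace", null_replacement="?"
    )
else:
    # The call shape the runtime needs: the trailing "" is the separator.
    joined = pc.binary_join_element_wise(
        a, b, "", null_handling="replace", null_replacement="?"
    )

print(joined.to_pylist())  # ['xy', '?z']
```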