
fix(typing): Resolve all mypy & pyright errors for _arrow #2007

Merged · 63 commits · Feb 17, 2025
e7f465f
fix(typing): Resolve all `pyright` warnings for `_arrow`
dangotbanned Feb 13, 2025
0015ac1
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 13, 2025
1d24dec
chore(typing): add extra ignore
dangotbanned Feb 13, 2025
286beeb
ci: omit `_arrow.typing.py` from coverage
dangotbanned Feb 13, 2025
757c6bd
fix(typing): error: Name "ser" already defined on line 59 [no-redef]
dangotbanned Feb 13, 2025
b830aa2
fix(DRAFT): try reordering `maybe_extract_py_scalar` overloads
dangotbanned Feb 13, 2025
5cadbed
ci(typing): re-enable `pyarrow-stubs`
dangotbanned Feb 13, 2025
1e6eb75
fix: only check for `.ordered` when type can have property
dangotbanned Feb 13, 2025
51ad255
fix(typing): use `utils.chunked_array`
dangotbanned Feb 13, 2025
658c963
fix(typing): misc assignment/redef errors
dangotbanned Feb 13, 2025
84033f1
fix(typing): error: "Table" has no attribute "__iter__" (not iterable…
dangotbanned Feb 13, 2025
f35b01e
chore(typing): lie about `broadcast_series`
dangotbanned Feb 13, 2025
1b537e7
fix(typing): error: Value of type variable "_ArrayT" of "concat_array…
dangotbanned Feb 13, 2025
89672e2
chore(DRAFT): add comment on `pyarrow.interchange.from_dataframe`
dangotbanned Feb 14, 2025
ce986e1
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 14, 2025
4d4a722
fix(typing): ambiguous `pyarrow.interchange.from_dataframe` import
dangotbanned Feb 14, 2025
e2940e7
chore(typing): help `mypy` match `@overload`
dangotbanned Feb 14, 2025
c7f23d5
chore(typing): mark `pc.max_element_wise` as `Incomplete` everywhere
dangotbanned Feb 14, 2025
b7d149d
chore(typing): more help `mypy` match `@overload`
dangotbanned Feb 14, 2025
1445db8
test(typing): resolve `[assignment]`
dangotbanned Feb 14, 2025
8c98172
test(typing): more `[assignment]`
dangotbanned Feb 14, 2025
a246ea7
test(typing): ignore `[arg-type]` when raising
dangotbanned Feb 14, 2025
cac7dbb
test(typing): even more `[assignment]`
dangotbanned Feb 14, 2025
35f801b
test(typing): even more helping `@overload` match
dangotbanned Feb 14, 2025
4a9ff86
test(typing): more ignore `[arg-type]` when raising
dangotbanned Feb 14, 2025
f86bf10
ci: ignore coverage within a test?
dangotbanned Feb 14, 2025
cf27f7d
test(typing): fix indirect `[assignment]` for `mypy`
dangotbanned Feb 14, 2025
3bdf0e8
refactor: replace `pa.scalar` w/ `lit` alias
dangotbanned Feb 14, 2025
27da267
fix(typing): use `nulls_like` in `Series.shift`
dangotbanned Feb 14, 2025
c43d092
fix(typing): widen `ArrowSeries.(mean|median)`
dangotbanned Feb 14, 2025
b549ee5
ci(typing): remove `pyarrow` comments
dangotbanned Feb 14, 2025
483db21
refactor: reuse `nulls_like`
dangotbanned Feb 14, 2025
44c8c52
feat(typing): correct methods returning `ArrowSeries[pa.BooleanScalar]`
dangotbanned Feb 14, 2025
bcec6f1
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 14, 2025
1b065f4
Merge branch 'main' into typing-major-fixing-1
dangotbanned Feb 14, 2025
41c9661
feat(typing): `series_str` return annotations
dangotbanned Feb 14, 2025
85b01a6
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 15, 2025
adef6a0
refactor(typing): use `Incomplete` for `nulls_like`
dangotbanned Feb 15, 2025
4a2cfc5
refactor(typing): mark `pc.binary_join_element_wise` as `Incomplete`
dangotbanned Feb 15, 2025
f704424
refactor(DRAFT): try simplifying `ArrowSeries.to_pandas`
dangotbanned Feb 15, 2025
f8617ff
revert: undo `ArrowSeries.to_pandas` change
dangotbanned Feb 15, 2025
b27665f
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 15, 2025
ba2b3a5
refactor(typing): narrow to `aggs: list[tuple[str, str, Any]]`
dangotbanned Feb 16, 2025
86e0c2a
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 16, 2025
8d13198
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 17, 2025
e5887eb
remove unneeded no cover
MarcoGorelli Feb 17, 2025
a947cf6
arrow scalar to _arrow/typing
MarcoGorelli Feb 17, 2025
5dcd2f3
rename to ArrowScalarT_co
MarcoGorelli Feb 17, 2025
44560d1
refactor(typing): Define `__lib_pxi.types._BasicDataType` internally
dangotbanned Feb 17, 2025
481a3e4
Merge remote-tracking branch 'upstream/main' into typing-major-fixing-1
dangotbanned Feb 17, 2025
9c267c8
fix: Make `ArrowScalatT_co` available at runtime
dangotbanned Feb 17, 2025
59af046
simplify
MarcoGorelli Feb 17, 2025
a18e1bc
Merge branch 'typing-major-fixing-1' of github.com:narwhals-dev/narwh…
MarcoGorelli Feb 17, 2025
bc02cdb
refactor: avoid `pyarrow.__lib_pxi.types` in `utils`
dangotbanned Feb 17, 2025
8c8dbe7
Merge branch 'typing-major-fixing-1' of https://github.com/narwhals-d…
dangotbanned Feb 17, 2025
8ca893c
fixup
MarcoGorelli Feb 17, 2025
e6cecbc
Merge branch 'typing-major-fixing-1' of https://github.com/narwhals-d…
dangotbanned Feb 17, 2025
aa62a32
remove unalive code
MarcoGorelli Feb 17, 2025
9507261
Merge branch 'typing-major-fixing-1' of github.com:narwhals-dev/narwh…
MarcoGorelli Feb 17, 2025
3ea9674
last one
MarcoGorelli Feb 17, 2025
9bba7aa
fixup tests
MarcoGorelli Feb 17, 2025
1ceda96
fixup typing
MarcoGorelli Feb 17, 2025
837f0e6
pyright ignore
dangotbanned Feb 17, 2025
45 changes: 22 additions & 23 deletions narwhals/_arrow/dataframe.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterator
@@ -36,17 +37,17 @@

import pandas as pd
import polars as pl
from pyarrow._stubs_typing import ( # pyright: ignore[reportMissingModuleSource]
Indices,
)
from pyarrow._stubs_typing import Order # pyright: ignore[reportMissingModuleSource]
from typing_extensions import Self
from typing_extensions import TypeAlias

from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.group_by import ArrowGroupBy
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import ArrowChunkedArray
from narwhals._arrow.typing import Indices
from narwhals._arrow.typing import Mask
from narwhals._arrow.typing import Order
from narwhals.dtypes import DType
from narwhals.typing import SizeUnit
from narwhals.typing import _1DArray
@@ -133,7 +134,7 @@ def __len__(self: Self) -> int:
return len(self._native_frame)

def row(self: Self, index: int) -> tuple[Any, ...]:
return tuple(col[index] for col in self._native_frame)
return tuple(col[index] for col in self._native_frame.itercolumns())

@overload
def rows(self: Self, *, named: Literal[True]) -> list[dict[str, Any]]: ...
@@ -371,7 +372,9 @@ def with_columns(self: Self, *exprs: ArrowExpr) -> Self:

native_frame = (
native_frame.set_column(
columns.index(col_name), field_=col_name, column=column
columns.index(col_name),
field_=col_name,
column=column, # type: ignore[arg-type]
)
if col_name in columns
else native_frame.append_column(field_=col_name, column=column)
@@ -532,17 +535,18 @@ def with_row_index(self: Self, name: str) -> Self:
df.append_column(name, row_indices).select([name, *cols])
)

def filter(self: Self, predicate: ArrowExpr | list[bool]) -> Self:
def filter(self: Self, predicate: ArrowExpr | list[bool | None]) -> Self:
if isinstance(predicate, list):
mask_native = predicate
mask_native: Mask | ArrowChunkedArray = predicate
else:
# `[0]` is safe as the predicate's expression only returns a single column
mask = evaluate_into_exprs(self, predicate)[0]
mask_native = broadcast_and_extract_dataframe_comparand(
length=len(self), other=mask, backend_version=self._backend_version
)
return self._from_native_frame(
self._native_frame.filter(mask_native), validate_column_names=False
self._native_frame.filter(mask_native), # pyright: ignore[reportArgumentType]
validate_column_names=False,
)

def head(self: Self, n: int) -> Self:
@@ -745,17 +749,14 @@ def unique(

agg_func = agg_func_map[keep]
col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
keep_idx = (
keep_idx_native = (
df.append_column(col_token, pa.array(np.arange(len(self))))
.group_by(subset)
.aggregate([(col_token, agg_func)])
.column(f"{col_token}_{agg_func}")
)

return self._from_native_frame(
pc.take(df, keep_idx), # type: ignore[call-overload, unused-ignore]
validate_column_names=False,
)
indices = cast("Indices", keep_idx_native)
return self._from_native_frame(df.take(indices), validate_column_names=False)

keep_idx = self.simple_select(*subset).is_unique()
plx = self.__narwhals_namespace__()
@@ -804,30 +805,28 @@ def unpivot(
on_: list[str] = (
[c for c in self.columns if c not in index_] if on is None else on
)

promote_kwargs: dict[Literal["promote_options"], PromoteOptions] = (
{"promote_options": "permissive"}
concat = (
partial(pa.concat_tables, promote_options="permissive")
if self._backend_version >= (14, 0, 0)
else {}
else pa.concat_tables
)
names = [*index_, variable_name, value_name]
return self._from_native_frame(
pa.concat_tables(
concat(
[
pa.Table.from_arrays(
[
*(native_frame.column(idx_col) for idx_col in index_),
cast(
"pa.ChunkedArray",
"ArrowChunkedArray",
pa.array([on_col] * n_rows, pa.string()),
),
native_frame.column(on_col),
],
names=names,
)
for on_col in on_
],
**promote_kwargs,
]
)
)
# TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
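The `unpivot` hunk above replaces a conditionally-built `promote_kwargs` dict with a version-gated callable: `functools.partial` pre-binds `promote_options="permissive"` when the backend is new enough, so the call site below needs no branch of its own. A minimal sketch of that pattern, using a hypothetical stand-in function and version tuple in place of `pa.concat_tables` and `self._backend_version`:

```python
from functools import partial

BACKEND_VERSION = (14, 0, 0)  # hypothetical: stands in for self._backend_version

def concat_tables(tables, promote_options="none"):
    """Stand-in for pa.concat_tables; returns which keyword it received."""
    return list(tables), promote_options

# Pre-bind the keyword only when the backend supports it (pyarrow >= 14),
# so every call site can use `concat` unconditionally.
concat = (
    partial(concat_tables, promote_options="permissive")
    if BACKEND_VERSION >= (14, 0, 0)
    else concat_tables
)

tables, options = concat([["t1"], ["t2"]])
```

The benefit over the old `**promote_kwargs` approach is that the type checker sees a single callable with a fixed signature instead of an unpacked dict whose keys it cannot verify.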
2 changes: 1 addition & 1 deletion narwhals/_arrow/expr.py
@@ -48,7 +48,7 @@ def __init__(
self._depth = depth
self._function_name = function_name
self._depth = depth
self._evaluate_output_names = evaluate_output_names
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
self._alias_output_names = alias_output_names
self._backend_version = backend_version
self._version = version
23 changes: 15 additions & 8 deletions narwhals/_arrow/group_by.py
@@ -5,6 +5,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterator
from typing import cast

import pyarrow as pa
import pyarrow.compute as pc
@@ -18,6 +19,7 @@

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.typing import Incomplete

POLARS_TO_ARROW_AGGREGATIONS = {
"sum": "sum",
@@ -68,7 +70,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:
)
raise ValueError(msg)

aggs: list[tuple[str, str, pc.FunctionOptions | None]] = []
aggs: list[tuple[str, str, Any]] = []
expected_pyarrow_column_names: list[str] = self._keys.copy()
new_column_names: list[str] = self._keys.copy()

@@ -91,7 +93,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:

function_name = re.sub(r"(\w+->)", "", expr._function_name)
if function_name in {"std", "var"}:
option = pc.VarianceOptions(ddof=expr._kwargs["ddof"])
option: Any = pc.VarianceOptions(ddof=expr._kwargs["ddof"])
elif function_name in {"len", "n_unique"}:
option = pc.CountOptions(mode="all")
elif function_name == "count":
@@ -139,14 +141,19 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame:

def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:
col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns)
null_token = "__null_token_value__" # noqa: S105
null_token: str = "__null_token_value__" # noqa: S105

table = self._df._native_frame
key_values = pc.binary_join_element_wise(
*[pc.cast(table[key], pa.string()) for key in self._keys],
"",
null_handling="replace",
null_replacement=null_token,
# NOTE: stubs fail in multiple places for `ChunkedArray`
it = cast(
"Iterator[pa.StringArray]",
(table[key].cast(pa.string()) for key in self._keys),
)
# NOTE: stubs indicate `separator` must also be a `ChunkedArray`
# Reality: `str` is fine
concat_str: Incomplete = pc.binary_join_element_wise
key_values = concat_str(
*it, "", null_handling="replace", null_replacement=null_token
)
table = table.add_column(i=0, field_=col_token, column=key_values)

53 changes: 30 additions & 23 deletions narwhals/_arrow/namespace.py
@@ -19,6 +19,7 @@
from narwhals._arrow.utils import broadcast_series
from narwhals._arrow.utils import diagonal_concat
from narwhals._arrow.utils import horizontal_concat
from narwhals._arrow.utils import nulls_like
from narwhals._arrow.utils import vertical_concat
from narwhals._expression_parsing import combine_alias_output_names
from narwhals._expression_parsing import combine_evaluate_output_names
@@ -31,6 +32,7 @@

from typing_extensions import Self

from narwhals._arrow.typing import Incomplete
from narwhals._arrow.typing import IntoArrowExpr
from narwhals.dtypes import DType
from narwhals.utils import Version
@@ -254,13 +256,16 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
def min_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
init_series, *series = [s for _expr in exprs for s in _expr(df)]
# NOTE: Stubs copy the wrong signature https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L963
min_element_wise: Incomplete = pc.min_element_wise
native_series = reduce(
min_element_wise,
[s._native_series for s in series],
init_series._native_series,
)
return [
ArrowSeries(
native_series=reduce(
pc.min_element_wise,
[s._native_series for s in series],
init_series._native_series,
),
native_series,
name=init_series.name,
backend_version=self._backend_version,
version=self._version,
@@ -279,13 +284,17 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
def max_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
init_series, *series = [s for _expr in exprs for s in _expr(df)]
# NOTE: stubs are missing `ChunkedArray` support
# https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L948-L954
max_element_wise: Incomplete = pc.max_element_wise
native_series = reduce(
max_element_wise,
[s._native_series for s in series],
init_series._native_series,
)
return [
ArrowSeries(
native_series=reduce(
pc.max_element_wise,
[s._native_series for s in series],
init_series._native_series,
),
native_series,
name=init_series.name,
backend_version=self._backend_version,
version=self._version,
@@ -347,18 +356,19 @@ def concat_str(
dtypes = import_dtypes_module(self._version)

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
compliant_series_list = [
compliant_series_list: list[ArrowSeries] = [
s for _expr in exprs for s in _expr.cast(dtypes.String())(df)
]
null_handling = "skip" if ignore_nulls else "emit_null"
result_series = pc.binary_join_element_wise(
*(s._native_series for s in compliant_series_list),
separator,
null_handling=null_handling,
null_handling: Literal["skip", "emit_null"] = (
"skip" if ignore_nulls else "emit_null"
)
it = (s._native_series for s in compliant_series_list)
# NOTE: stubs indicate `separator` must also be a `ChunkedArray`
# Reality: `str` is fine
concat_str: Incomplete = pc.binary_join_element_wise
return [
ArrowSeries(
native_series=result_series,
native_series=concat_str(*it, separator, null_handling=null_handling),
name=compliant_series_list[0].name,
backend_version=self._backend_version,
version=self._version,
@@ -410,14 +420,11 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
condition_native, value_series_native = broadcast_series(
[condition, value_series]
)

if self._otherwise_value is None:
otherwise_native = pa.repeat(
pa.scalar(None, type=value_series_native.type), len(condition_native)
)
otherwise_null = nulls_like(len(condition_native), value_series)
return [
value_series._from_native_series(
pc.if_else(condition_native, value_series_native, otherwise_native)
pc.if_else(condition_native, value_series_native, otherwise_null)
)
]
if isinstance(self._otherwise_value, ArrowExpr):
@@ -474,7 +481,7 @@ def __init__(
self._call = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue]
self._alias_output_names = alias_output_names
self._kwargs = kwargs

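Two of the fixes above (`cast("Indices", keep_idx_native)` in `unique`, and `cast("Iterator[pa.StringArray]", ...)` in `group_by.__iter__`) lean on the fact that `typing.cast` is a pure type-level assertion: it returns its argument unchanged at runtime, so re-labelling a value for the stubs' benefit costs nothing and performs no validation. A minimal illustration, with a plain list standing in for the pyarrow column:

```python
from typing import cast

keep_idx = [3, 0, 2]  # hypothetical: stands in for the aggregated index column

# cast() does no conversion and no checking at runtime; it only tells the
# static checker to treat the value as the given type from here on.
indices = cast("list[int]", keep_idx)

same_object = indices is keep_idx  # True: the exact same object comes back
```

This is why `cast` is safe to reach for when the stubs, not the runtime, are wrong; the risk is equally that a mistaken cast is never caught at runtime.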