Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: spark like .list namespace and lazyframe .explode #1887

Merged
merged 3 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions narwhals/_duckdb/expr_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
from duckdb import FunctionExpression

if TYPE_CHECKING:
from typing_extensions import Self

from narwhals._duckdb.expr import DuckDBExpr


class DuckDBExprListNamespace:
def __init__(self, expr: DuckDBExpr) -> None:
def __init__(self: Self, expr: DuckDBExpr) -> None:
self._compliant_expr = expr

def len(self) -> DuckDBExpr:
def len(self: Self) -> DuckDBExpr:
return self._compliant_expr._from_call(
lambda _input: FunctionExpression("len", _input),
"len",
Expand Down
62 changes: 59 additions & 3 deletions narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
from narwhals.typing import CompliantLazyFrame
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
from narwhals.utils import import_dtypes_module
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version
from narwhals.utils import validate_backend_version

if TYPE_CHECKING:
from types import ModuleType

from pyspark.sql import Column
from pyspark.sql import DataFrame
from typing_extensions import Self

Expand Down Expand Up @@ -46,7 +48,7 @@ def __init__(
validate_backend_version(self._implementation, self._backend_version)

@property
def _F(self) -> Any: # noqa: N802
def _F(self: Self) -> Any: # noqa: N802
if self._implementation is Implementation.SQLFRAME:
from sqlframe.duckdb import functions

Expand All @@ -56,7 +58,7 @@ def _F(self) -> Any: # noqa: N802
return functions

@property
def _native_dtypes(self) -> Any:
def _native_dtypes(self: Self) -> Any:
if self._implementation is Implementation.SQLFRAME:
from sqlframe.duckdb import types

Expand All @@ -66,7 +68,7 @@ def _native_dtypes(self) -> Any:
return types

@property
def _Window(self) -> Any: # noqa: N802
def _Window(self: Self) -> Any: # noqa: N802
if self._implementation is Implementation.SQLFRAME:
from sqlframe.duckdb import Window

Expand Down Expand Up @@ -312,3 +314,57 @@ def join(
return self._from_native_frame(
self_native.join(other, on=left_on, how=how).select(col_order)
)

def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
    """Explode a single List column, producing one row per list element.

    Rows whose list is null or empty are re-introduced (via the union
    below) as a single row carrying a null in the exploded column.

    Arguments:
        columns: Name of the List column to explode, or a sequence of names.
        more_columns: Additional column names to explode.

    Raises:
        InvalidOperationError: If any requested column is not of List dtype.
        NotImplementedError: If more than one column is requested — this
            backend cannot verify matching element counts without a collect.
    """
    from narwhals.exceptions import InvalidOperationError

    dtypes = import_dtypes_module(self._version)

    # Normalize the `(columns, *more_columns)` signature into a flat list.
    target_columns = (
        [columns, *more_columns]
        if isinstance(columns, str)
        else [*columns, *more_columns]
    )

    # Every column being exploded must hold List values.
    schema = self.collect_schema()
    for target in target_columns:
        dtype = schema[target]
        if dtype != dtypes.List:
            msg = (
                f"`explode` operation not supported for dtype `{dtype}`, "
                "expected List type"
            )
            raise InvalidOperationError(msg)

    if len(target_columns) != 1:
        msg = (
            "Exploding on multiple columns is not supported with SparkLike backend since "
            "we cannot guarantee that the exploded columns have matching element counts."
        )
        raise NotImplementedError(msg)

    exploded_column = target_columns[0]
    native_frame = self._native_frame
    column_names = self.columns

    def is_null_or_empty(col_name: str) -> Column:
        # Selects the rows that need a null placeholder after exploding.
        return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0)

    # One row per list element; null/empty-list rows are absent here.
    exploded_rows = native_frame.select(
        *[
            self._F.explode(name).alias(name)
            if name == exploded_column
            else self._F.col(name).alias(name)
            for name in column_names
        ]
    )
    # Re-add one row per null/empty list, with a null in the exploded column.
    placeholder_rows = native_frame.filter(is_null_or_empty(exploded_column)).select(
        *[
            self._F.lit(None).alias(name)
            if name == exploded_column
            else self._F.col(name).alias(name)
            for name in column_names
        ]
    )
    return self._from_native_frame(exploded_rows.union(placeholder_rows))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a way... this is so sad since I had a great way of doing it! But we cannot verify that the element counts are the same without triggering a collect.

5 changes: 5 additions & 0 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Sequence

from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
from narwhals._spark_like.expr_name import SparkLikeExprNameNamespace
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
from narwhals._spark_like.utils import ExprKind
Expand Down Expand Up @@ -556,3 +557,7 @@ def name(self: Self) -> SparkLikeExprNameNamespace:
@property
def dt(self: Self) -> SparkLikeExprDateTimeNamespace:
    """Namespace of datetime-related expression methods (``expr.dt``)."""
    return SparkLikeExprDateTimeNamespace(self)

@property
def list(self: Self) -> SparkLikeExprListNamespace:
    """Namespace of list-related expression methods (``expr.list``)."""
    return SparkLikeExprListNamespace(self)
20 changes: 20 additions & 0 deletions narwhals/_spark_like/expr_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing_extensions import Self

    from narwhals._spark_like.expr import SparkLikeExpr


class SparkLikeExprListNamespace:
    """Implements the ``.list`` namespace for Spark-like expressions."""

    def __init__(self: Self, expr: SparkLikeExpr) -> None:
        self._compliant_expr = expr

    def len(self: Self) -> SparkLikeExpr:
        """Return the number of elements in each list value."""
        compliant = self._compliant_expr
        return compliant._from_call(
            compliant._F.array_size,
            "len",
            expr_kind=compliant._expr_kind,
        )
29 changes: 16 additions & 13 deletions narwhals/_spark_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,26 +57,25 @@ def native_to_narwhals_dtype(
return dtypes.Int16()
if isinstance(dtype, spark_types.ByteType):
return dtypes.Int8()
string_types = [
spark_types.StringType,
spark_types.VarcharType,
spark_types.CharType,
]
if any(isinstance(dtype, t) for t in string_types):
if isinstance(
dtype, (spark_types.StringType, spark_types.VarcharType, spark_types.CharType)
):
return dtypes.String()
if isinstance(dtype, spark_types.BooleanType):
return dtypes.Boolean()
if isinstance(dtype, spark_types.DateType):
return dtypes.Date()
datetime_types = [
spark_types.TimestampType,
spark_types.TimestampNTZType,
]
if any(isinstance(dtype, t) for t in datetime_types):
if isinstance(dtype, (spark_types.TimestampType, spark_types.TimestampNTZType)):
return dtypes.Datetime()
if isinstance(dtype, spark_types.DecimalType): # pragma: no cover
# TODO(unassigned): cover this in dtypes_test.py
return dtypes.Decimal()
if isinstance(dtype, spark_types.ArrayType): # pragma: no cover
return dtypes.List(
inner=native_to_narwhals_dtype(
dtype.elementType, version=version, spark_types=spark_types
)
)
return dtypes.Unknown()


Expand Down Expand Up @@ -105,8 +104,12 @@ def narwhals_to_native_dtype(
msg = "Converting to Date or Datetime dtype is not supported yet"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.List): # pragma: no cover
msg = "Converting to List dtype is not supported yet"
raise NotImplementedError(msg)
inner = narwhals_to_native_dtype(
dtype.inner, # type: ignore[union-attr]
version=version,
spark_types=spark_types,
)
return spark_types.ArrayType(elementType=inner)
if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover
msg = "Converting to Struct dtype is not supported yet"
raise NotImplementedError(msg)
Expand Down
4 changes: 1 addition & 3 deletions tests/expr_and_series/list/len_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ def test_len_expr(
request: pytest.FixtureRequest,
constructor: Constructor,
) -> None:
if any(
backend in str(constructor) for backend in ("dask", "modin", "cudf", "pyspark")
):
if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")):
request.applymarker(pytest.mark.xfail)

if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
Expand Down
10 changes: 5 additions & 5 deletions tests/frame/explode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_explode_single_col(
) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark")
for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb")
):
request.applymarker(pytest.mark.xfail)

Expand Down Expand Up @@ -110,16 +110,16 @@ def test_explode_shape_error(
) -> None:
if any(
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark")
for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb")
):
request.applymarker(pytest.mark.xfail)

if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
request.applymarker(pytest.mark.xfail)

with pytest.raises(
(ShapeError, PlShapeError),
match="exploded columns must have matching element counts",
(ShapeError, PlShapeError, NotImplementedError),
match=r".*exploded columns (must )?have matching element counts",
):
_ = (
nw.from_native(constructor(data))
Expand All @@ -133,7 +133,7 @@ def test_explode_shape_error(
def test_explode_invalid_operation_error(
request: pytest.FixtureRequest, constructor: Constructor
) -> None:
if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb", "pyspark")):
if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb")):
request.applymarker(pytest.mark.xfail)

if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6):
Expand Down
7 changes: 6 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def _sort_dict_by_key(
) -> dict[str, list[Any]]: # pragma: no cover
sort_list = data_dict[key]
sorted_indices = sorted(
range(len(sort_list)), key=lambda i: (sort_list[i] is None, sort_list[i])
range(len(sort_list)),
key=lambda i: (
(sort_list[i] is None)
or (isinstance(sort_list[i], float) and math.isnan(sort_list[i])),
sort_list[i],
),
)
return {key: [value[i] for i in sorted_indices] for key, value in data_dict.items()}

Expand Down
Loading