From a0a47882dec80cedf7a7a9ce156ae861c48813ad Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Tue, 28 Jan 2025 20:37:20 +0100 Subject: [PATCH 1/2] feat: spark like list namespace and explode --- narwhals/_duckdb/expr_list.py | 6 ++- narwhals/_spark_like/dataframe.py | 62 ++++++++++++++++++++++++-- narwhals/_spark_like/expr.py | 5 +++ narwhals/_spark_like/expr_list.py | 20 +++++++++ narwhals/_spark_like/utils.py | 29 ++++++------ tests/expr_and_series/list/len_test.py | 4 +- tests/frame/explode_test.py | 10 ++--- tests/utils.py | 7 ++- 8 files changed, 116 insertions(+), 27 deletions(-) create mode 100644 narwhals/_spark_like/expr_list.py diff --git a/narwhals/_duckdb/expr_list.py b/narwhals/_duckdb/expr_list.py index 134df90b6..a2277c4b8 100644 --- a/narwhals/_duckdb/expr_list.py +++ b/narwhals/_duckdb/expr_list.py @@ -5,14 +5,16 @@ from duckdb import FunctionExpression if TYPE_CHECKING: + from typing_extensions import Self + from narwhals._duckdb.expr import DuckDBExpr class DuckDBExprListNamespace: - def __init__(self, expr: DuckDBExpr) -> None: + def __init__(self: Self, expr: DuckDBExpr) -> None: self._compliant_expr = expr - def len(self) -> DuckDBExpr: + def len(self: Self) -> DuckDBExpr: return self._compliant_expr._from_call( lambda _input: FunctionExpression("len", _input), "len", diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 13945b060..75443c30d 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -12,6 +12,7 @@ from narwhals.typing import CompliantLazyFrame from narwhals.utils import Implementation from narwhals.utils import check_column_exists +from narwhals.utils import import_dtypes_module from narwhals.utils import parse_columns_to_drop from narwhals.utils import parse_version from narwhals.utils import validate_backend_version @@ -19,6 +20,7 @@ if TYPE_CHECKING: from types import ModuleType + from pyspark.sql import Column from pyspark.sql import DataFrame from typing_extensions import Self @@ -46,7 +48,7 @@ def __init__( validate_backend_version(self._implementation, self._backend_version) @property - def _F(self) -> Any: # noqa: N802 + def _F(self: Self) -> Any: # noqa: N802 if self._implementation is Implementation.SQLFRAME: from sqlframe.duckdb import functions @@ -56,7 +58,7 @@ def _F(self) -> Any: # noqa: N802 return functions @property - def _native_dtypes(self) -> Any: + def _native_dtypes(self: Self) -> Any: if self._implementation is Implementation.SQLFRAME: from sqlframe.duckdb import types @@ -66,7 +68,7 @@ def _native_dtypes(self) -> Any: return types @property - def _Window(self) -> Any: # noqa: N802 + def _Window(self: Self) -> Any: # noqa: N802 if self._implementation is Implementation.SQLFRAME: from sqlframe.duckdb import Window @@ -312,3 +314,57 @@ def join( return self._from_native_frame( self_native.join(other, on=left_on, how=how).select(col_order) ) + + def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: + from narwhals.exceptions import InvalidOperationError + + dtypes = import_dtypes_module(self._version) + + to_explode = ( + [columns, *more_columns] + if isinstance(columns, str) + else [*columns, *more_columns] + ) + schema = self.collect_schema() + for col_to_explode in to_explode: + dtype = schema[col_to_explode] + + if dtype != dtypes.List: + msg = ( + f"`explode` operation not supported for dtype `{dtype}`, " + "expected List type" + ) + raise InvalidOperationError(msg) + + native_frame = self._native_frame + column_names = self.columns + + def null_condition(col_name: str) -> Column: + return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0) + + if len(to_explode) == 1: + return self._from_native_frame( + native_frame.select( + *[ + self._F.col(col_name).alias(col_name) + if col_name != to_explode[0] + else self._F.explode(col_name).alias(col_name) + for col_name in column_names + ] + ).union( + native_frame.filter(null_condition(to_explode[0])).select( + *[ + self._F.col(col_name).alias(col_name) + if col_name != to_explode[0] + else self._F.lit(None).alias(col_name) + for col_name in column_names + ] + ) + ) + ) + + msg = ( + "Exploding on multiple columns is not supported with SparkLike backend since " + "we cannot guarantee that the exploded columns have matching element counts." + ) + raise NotImplementedError(msg) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index f3a9de7f0..d3f590ad2 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -7,6 +7,7 @@ from typing import Sequence from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace +from narwhals._spark_like.expr_list import SparkLikeExprListNamespace from narwhals._spark_like.expr_name import SparkLikeExprNameNamespace from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace from narwhals._spark_like.utils import ExprKind @@ -556,3 +557,7 @@ def name(self: Self) -> SparkLikeExprNameNamespace: @property def dt(self: Self) -> SparkLikeExprDateTimeNamespace: return SparkLikeExprDateTimeNamespace(self) + + @property + def list(self: Self) -> SparkLikeExprListNamespace: + return SparkLikeExprListNamespace(self) diff --git a/narwhals/_spark_like/expr_list.py b/narwhals/_spark_like/expr_list.py new file mode 100644 index 000000000..ba0dc3189 --- /dev/null +++ b/narwhals/_spark_like/expr_list.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Self + + from narwhals._spark_like.expr import SparkLikeExpr + + +class SparkLikeExprListNamespace: + def __init__(self: Self, expr: SparkLikeExpr) -> None: + self._compliant_expr = expr + + def len(self: Self) -> SparkLikeExpr: + return self._compliant_expr._from_call( + self._compliant_expr._F.array_size, + "len", + expr_kind=self._compliant_expr._expr_kind, + ) diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index 82d7feb28..3a728d5c5 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -57,26 +57,25 @@ def native_to_narwhals_dtype( return dtypes.Int16() if isinstance(dtype, spark_types.ByteType): return dtypes.Int8() - string_types = [ - spark_types.StringType, - spark_types.VarcharType, - spark_types.CharType, - ] - if any(isinstance(dtype, t) for t in string_types): + if isinstance( + dtype, (spark_types.StringType, spark_types.VarcharType, spark_types.CharType) + ): return dtypes.String() if isinstance(dtype, spark_types.BooleanType): return dtypes.Boolean() if isinstance(dtype, spark_types.DateType): return dtypes.Date() - datetime_types = [ - spark_types.TimestampType, - spark_types.TimestampNTZType, - ] - if any(isinstance(dtype, t) for t in datetime_types): + if isinstance(dtype, (spark_types.TimestampType, spark_types.TimestampNTZType)): return dtypes.Datetime() if isinstance(dtype, spark_types.DecimalType): # pragma: no cover # TODO(unassigned): cover this in dtypes_test.py return dtypes.Decimal() + if isinstance(dtype, spark_types.ArrayType): # pragma: no cover + return dtypes.List( + inner=native_to_narwhals_dtype( + dtype.elementType, version=version, spark_types=spark_types + ) + ) return dtypes.Unknown() @@ -105,8 +104,12 @@ def narwhals_to_native_dtype( msg = "Converting to Date or Datetime dtype is not supported yet" raise NotImplementedError(msg) if isinstance_or_issubclass(dtype, dtypes.List): # pragma: no cover - msg = "Converting to List dtype is not supported yet" - raise NotImplementedError(msg) + inner = narwhals_to_native_dtype( + dtype.inner, # type: ignore[union-attr] + version=version, + spark_types=spark_types, + ) + return spark_types.ArrayType(elementType=inner) if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover msg = "Converting to Struct dtype is not supported yet" raise NotImplementedError(msg) diff --git a/tests/expr_and_series/list/len_test.py b/tests/expr_and_series/list/len_test.py index 375cfc7d8..7066fc6cf 100644 --- a/tests/expr_and_series/list/len_test.py +++ b/tests/expr_and_series/list/len_test.py @@ -17,9 +17,7 @@ def test_len_expr( request: pytest.FixtureRequest, constructor: Constructor, ) -> None: - if any( - backend in str(constructor) for backend in ("dask", "modin", "cudf", "pyspark") - ): + if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): request.applymarker(pytest.mark.xfail) if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): diff --git a/tests/frame/explode_test.py b/tests/frame/explode_test.py index f3b096194..db5a4fc5a 100644 --- a/tests/frame/explode_test.py +++ b/tests/frame/explode_test.py @@ -40,7 +40,7 @@ def test_explode_single_col( ) -> None: if any( backend in str(constructor) - for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark") + for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb") ): request.applymarker(pytest.mark.xfail) @@ -110,7 +110,7 @@ def test_explode_shape_error( ) -> None: if any( backend in str(constructor) - for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark") + for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb") ): request.applymarker(pytest.mark.xfail) @@ -118,8 +118,8 @@ def test_explode_shape_error( request.applymarker(pytest.mark.xfail) with pytest.raises( - (ShapeError, PlShapeError), - match="exploded columns must have matching element counts", + (ShapeError, PlShapeError, NotImplementedError), + match=r".*exploded columns (must )?have matching element counts", ): _ = ( nw.from_native(constructor(data)) @@ -133,7 +133,7 @@ def test_explode_shape_error( def test_explode_invalid_operation_error( request: pytest.FixtureRequest, constructor: Constructor ) -> None: - if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb", "pyspark")): + if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb")): request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6): diff --git a/tests/utils.py b/tests/utils.py index f4f612619..7174fbb9e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -66,7 +66,12 @@ def _sort_dict_by_key( ) -> dict[str, list[Any]]: # pragma: no cover sort_list = data_dict[key] sorted_indices = sorted( - range(len(sort_list)), key=lambda i: (sort_list[i] is None, sort_list[i]) + range(len(sort_list)), + key=lambda i: ( + (sort_list[i] is None) + or (isinstance(sort_list[i], float) and math.isnan(sort_list[i])), + sort_list[i], + ), ) return {key: [value[i] for i in sorted_indices] for key, value in data_dict.items()} From f552cd098196db65e350662f37b92bbb24bb175c Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Thu, 30 Jan 2025 10:35:31 +0100 Subject: [PATCH 2/2] some cleanup and use F.explode_outer --- narwhals/_spark_like/dataframe.py | 44 +++++++++++-------------------- 1 file changed, 15 insertions(+), 29 deletions(-) diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 75443c30d..0661f3cfb 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -9,6 +9,7 @@ from narwhals._spark_like.utils import ExprKind from narwhals._spark_like.utils import native_to_narwhals_dtype from narwhals._spark_like.utils import parse_exprs_and_named_exprs +from narwhals.exceptions import InvalidOperationError from narwhals.typing import CompliantLazyFrame from narwhals.utils import Implementation from narwhals.utils import check_column_exists @@ -20,7 +21,6 @@ if TYPE_CHECKING: from types import ModuleType - from pyspark.sql import Column from pyspark.sql import DataFrame from typing_extensions import Self @@ -316,8 +316,6 @@ def join( ) def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self: - from narwhals.exceptions import InvalidOperationError - dtypes = import_dtypes_module(self._version) to_explode = ( @@ -339,32 +337,20 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel native_frame = self._native_frame column_names = self.columns - def null_condition(col_name: str) -> Column: - return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0) - - if len(to_explode) == 1: - return self._from_native_frame( - native_frame.select( - *[ - self._F.col(col_name).alias(col_name) - if col_name != to_explode[0] - else self._F.explode(col_name).alias(col_name) - for col_name in column_names - ] - ).union( - native_frame.filter(null_condition(to_explode[0])).select( - *[ - self._F.col(col_name).alias(col_name) - if col_name != to_explode[0] - else self._F.lit(None).alias(col_name) - for col_name in column_names - ] - ) - ) + if len(to_explode) != 1: + msg = ( + "Exploding on multiple columns is not supported with SparkLike backend since " + "we cannot guarantee that the exploded columns have matching element counts." ) + raise NotImplementedError(msg) - msg = ( - "Exploding on multiple columns is not supported with SparkLike backend since " - "we cannot guarantee that the exploded columns have matching element counts." + return self._from_native_frame( + native_frame.select( + *[ + self._F.col(col_name).alias(col_name) + if col_name != to_explode[0] + else self._F.explode_outer(col_name).alias(col_name) + for col_name in column_names + ] + ) ) - raise NotImplementedError(msg)