From 83c1b2aa35b41c294daa55f095118999dec130c6 Mon Sep 17 00:00:00 2001
From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>
Date: Thu, 30 Jan 2025 19:50:59 +0100
Subject: [PATCH] feat: spark like `.list` namespace and lazyframe `.explode`
 (#1887)

* feat: spark like list namespace and explode

* some cleanup and use F.explode_outer
---
 narwhals/_duckdb/expr_list.py          |  6 ++--
 narwhals/_spark_like/dataframe.py      | 48 ++++++++++++++++++++++++--
 narwhals/_spark_like/expr.py           |  5 +++
 narwhals/_spark_like/expr_list.py      | 20 +++++++++++
 narwhals/_spark_like/utils.py          | 29 +++++++++-------
 tests/expr_and_series/list/len_test.py |  4 +--
 tests/frame/explode_test.py            | 10 +++---
 tests/utils.py                         |  7 +++-
 8 files changed, 102 insertions(+), 27 deletions(-)
 create mode 100644 narwhals/_spark_like/expr_list.py

diff --git a/narwhals/_duckdb/expr_list.py b/narwhals/_duckdb/expr_list.py
index 134df90b6..a2277c4b8 100644
--- a/narwhals/_duckdb/expr_list.py
+++ b/narwhals/_duckdb/expr_list.py
@@ -5,14 +5,16 @@
 from duckdb import FunctionExpression
 
 if TYPE_CHECKING:
+    from typing_extensions import Self
+
     from narwhals._duckdb.expr import DuckDBExpr
 
 
 class DuckDBExprListNamespace:
-    def __init__(self, expr: DuckDBExpr) -> None:
+    def __init__(self: Self, expr: DuckDBExpr) -> None:
         self._compliant_expr = expr
 
-    def len(self) -> DuckDBExpr:
+    def len(self: Self) -> DuckDBExpr:
         return self._compliant_expr._from_call(
             lambda _input: FunctionExpression("len", _input),
             "len",
diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py
index 13945b060..0661f3cfb 100644
--- a/narwhals/_spark_like/dataframe.py
+++ b/narwhals/_spark_like/dataframe.py
@@ -9,9 +9,11 @@
 from narwhals._spark_like.utils import ExprKind
 from narwhals._spark_like.utils import native_to_narwhals_dtype
 from narwhals._spark_like.utils import parse_exprs_and_named_exprs
+from narwhals.exceptions import InvalidOperationError
 from narwhals.typing import CompliantLazyFrame
 from narwhals.utils import Implementation
 from narwhals.utils import check_column_exists
+from narwhals.utils import import_dtypes_module
 from narwhals.utils import parse_columns_to_drop
 from narwhals.utils import parse_version
 from narwhals.utils import validate_backend_version
@@ -46,7 +48,7 @@ def __init__(
         validate_backend_version(self._implementation, self._backend_version)
 
     @property
-    def _F(self) -> Any:  # noqa: N802
+    def _F(self: Self) -> Any:  # noqa: N802
         if self._implementation is Implementation.SQLFRAME:
             from sqlframe.duckdb import functions
 
@@ -56,7 +58,7 @@ def _F(self) -> Any:  # noqa: N802
         return functions
 
     @property
-    def _native_dtypes(self) -> Any:
+    def _native_dtypes(self: Self) -> Any:
         if self._implementation is Implementation.SQLFRAME:
             from sqlframe.duckdb import types
 
@@ -66,7 +68,7 @@ def _native_dtypes(self) -> Any:
         return types
 
     @property
-    def _Window(self) -> Any:  # noqa: N802
+    def _Window(self: Self) -> Any:  # noqa: N802
         if self._implementation is Implementation.SQLFRAME:
             from sqlframe.duckdb import Window
 
@@ -312,3 +314,43 @@ def join(
         return self._from_native_frame(
             self_native.join(other, on=left_on, how=how).select(col_order)
         )
+
+    def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
+        dtypes = import_dtypes_module(self._version)
+
+        to_explode = (
+            [columns, *more_columns]
+            if isinstance(columns, str)
+            else [*columns, *more_columns]
+        )
+        schema = self.collect_schema()
+        for col_to_explode in to_explode:
+            dtype = schema[col_to_explode]
+
+            if dtype != dtypes.List:
+                msg = (
+                    f"`explode` operation not supported for dtype `{dtype}`, "
+                    "expected List type"
+                )
+                raise InvalidOperationError(msg)
+
+        native_frame = self._native_frame
+        column_names = self.columns
+
+        if len(to_explode) != 1:
+            msg = (
+                "Exploding on multiple columns is not supported with SparkLike backend since "
+                "we cannot guarantee that the exploded columns have matching element counts."
+            )
+            raise NotImplementedError(msg)
+
+        return self._from_native_frame(
+            native_frame.select(
+                *[
+                    self._F.col(col_name).alias(col_name)
+                    if col_name != to_explode[0]
+                    else self._F.explode_outer(col_name).alias(col_name)
+                    for col_name in column_names
+                ]
+            )
+        )
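For review convenience, here is a minimal sketch of how the new SparkLike `explode` surfaces through the public API. The column name `a`, the sample rows, and the local SparkSession are illustrative assumptions, not part of this patch:

    import narwhals as nw
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    # One list-typed column; `explode` validates that its dtype is List.
    df_native = spark.createDataFrame([{"a": [1, 2]}, {"a": [3]}, {"a": None}])

    lf = nw.from_native(df_native)
    # Only single-column explode is supported here; multiple columns raise
    # NotImplementedError, since matching element counts cannot be guaranteed.
    lf.explode("a").to_native().show()

Using `F.explode_outer` rather than `F.explode` is what keeps a row whose list is NULL (or empty) as a single NULL row instead of silently dropping it.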
diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py
index f3a9de7f0..d3f590ad2 100644
--- a/narwhals/_spark_like/expr.py
+++ b/narwhals/_spark_like/expr.py
@@ -7,6 +7,7 @@
 from typing import Sequence
 
 from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
+from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
 from narwhals._spark_like.expr_name import SparkLikeExprNameNamespace
 from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
 from narwhals._spark_like.utils import ExprKind
@@ -556,3 +557,7 @@ def name(self: Self) -> SparkLikeExprNameNamespace:
     @property
     def dt(self: Self) -> SparkLikeExprDateTimeNamespace:
         return SparkLikeExprDateTimeNamespace(self)
+
+    @property
+    def list(self: Self) -> SparkLikeExprListNamespace:
+        return SparkLikeExprListNamespace(self)
diff --git a/narwhals/_spark_like/expr_list.py b/narwhals/_spark_like/expr_list.py
new file mode 100644
index 000000000..ba0dc3189
--- /dev/null
+++ b/narwhals/_spark_like/expr_list.py
@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from typing_extensions import Self
+
+    from narwhals._spark_like.expr import SparkLikeExpr
+
+
+class SparkLikeExprListNamespace:
+    def __init__(self: Self, expr: SparkLikeExpr) -> None:
+        self._compliant_expr = expr
+
+    def len(self: Self) -> SparkLikeExpr:
+        return self._compliant_expr._from_call(
+            self._compliant_expr._F.array_size,
+            "len",
+            expr_kind=self._compliant_expr._expr_kind,
+        )
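And a matching sketch for the new `.list` namespace (same illustrative `lf` as above; note that Spark's `array_size` returns NULL for a NULL list rather than 0):

    import narwhals as nw

    # Add the per-row length of the list column `a` as a new column.
    lf.with_columns(a_len=nw.col("a").list.len()).to_native().show()

This mirrors the DuckDB `.list` namespace touched above, where the same operation is backed by DuckDB's `len` function.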
diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py
index 82d7feb28..3a728d5c5 100644
--- a/narwhals/_spark_like/utils.py
+++ b/narwhals/_spark_like/utils.py
@@ -57,26 +57,25 @@ def native_to_narwhals_dtype(
         return dtypes.Int16()
     if isinstance(dtype, spark_types.ByteType):
         return dtypes.Int8()
-    string_types = [
-        spark_types.StringType,
-        spark_types.VarcharType,
-        spark_types.CharType,
-    ]
-    if any(isinstance(dtype, t) for t in string_types):
+    if isinstance(
+        dtype, (spark_types.StringType, spark_types.VarcharType, spark_types.CharType)
+    ):
         return dtypes.String()
     if isinstance(dtype, spark_types.BooleanType):
         return dtypes.Boolean()
     if isinstance(dtype, spark_types.DateType):
         return dtypes.Date()
-    datetime_types = [
-        spark_types.TimestampType,
-        spark_types.TimestampNTZType,
-    ]
-    if any(isinstance(dtype, t) for t in datetime_types):
+    if isinstance(dtype, (spark_types.TimestampType, spark_types.TimestampNTZType)):
         return dtypes.Datetime()
     if isinstance(dtype, spark_types.DecimalType):  # pragma: no cover
         # TODO(unassigned): cover this in dtypes_test.py
         return dtypes.Decimal()
+    if isinstance(dtype, spark_types.ArrayType):  # pragma: no cover
+        return dtypes.List(
+            inner=native_to_narwhals_dtype(
+                dtype.elementType, version=version, spark_types=spark_types
+            )
+        )
     return dtypes.Unknown()
 
 
@@ -105,8 +104,12 @@ def narwhals_to_native_dtype(
         msg = "Converting to Date or Datetime dtype is not supported yet"
         raise NotImplementedError(msg)
     if isinstance_or_issubclass(dtype, dtypes.List):  # pragma: no cover
-        msg = "Converting to List dtype is not supported yet"
-        raise NotImplementedError(msg)
+        inner = narwhals_to_native_dtype(
+            dtype.inner,  # type: ignore[union-attr]
+            version=version,
+            spark_types=spark_types,
+        )
+        return spark_types.ArrayType(elementType=inner)
     if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
         msg = "Converting to Struct dtype is not supported yet"
         raise NotImplementedError(msg)
diff --git a/tests/expr_and_series/list/len_test.py b/tests/expr_and_series/list/len_test.py
index 375cfc7d8..7066fc6cf 100644
--- a/tests/expr_and_series/list/len_test.py
+++ b/tests/expr_and_series/list/len_test.py
@@ -17,9 +17,7 @@ def test_len_expr(
     request: pytest.FixtureRequest,
     constructor: Constructor,
 ) -> None:
-    if any(
-        backend in str(constructor) for backend in ("dask", "modin", "cudf", "pyspark")
-    ):
+    if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")):
         request.applymarker(pytest.mark.xfail)
 
     if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
diff --git a/tests/frame/explode_test.py b/tests/frame/explode_test.py
index f3b096194..db5a4fc5a 100644
--- a/tests/frame/explode_test.py
+++ b/tests/frame/explode_test.py
@@ -40,7 +40,7 @@ def test_explode_single_col(
 ) -> None:
     if any(
         backend in str(constructor)
-        for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark")
+        for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb")
     ):
         request.applymarker(pytest.mark.xfail)
 
@@ -110,7 +110,7 @@ def test_explode_shape_error(
 ) -> None:
     if any(
         backend in str(constructor)
-        for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark")
+        for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb")
    ):
         request.applymarker(pytest.mark.xfail)
 
@@ -118,8 +118,8 @@
         request.applymarker(pytest.mark.xfail)
 
     with pytest.raises(
-        (ShapeError, PlShapeError),
-        match="exploded columns must have matching element counts",
+        (ShapeError, PlShapeError, NotImplementedError),
+        match=r".*exploded columns (must )?have matching element counts",
     ):
         _ = (
             nw.from_native(constructor(data))
@@ -133,7 +133,7 @@ def test_explode_invalid_operation_error(
     request: pytest.FixtureRequest, constructor: Constructor
 ) -> None:
-    if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb", "pyspark")):
+    if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb")):
         request.applymarker(pytest.mark.xfail)
 
     if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6):
         request.applymarker(pytest.mark.xfail)
diff --git a/tests/utils.py b/tests/utils.py
index f4f612619..7174fbb9e 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -66,7 +66,12 @@ def _sort_dict_by_key(
 ) -> dict[str, list[Any]]:  # pragma: no cover
     sort_list = data_dict[key]
     sorted_indices = sorted(
-        range(len(sort_list)), key=lambda i: (sort_list[i] is None, sort_list[i])
+        range(len(sort_list)),
+        key=lambda i: (
+            (sort_list[i] is None)
+            or (isinstance(sort_list[i], float) and math.isnan(sort_list[i])),
+            sort_list[i],
+        ),
     )
     return {key: [value[i] for i in sorted_indices] for key, value in data_dict.items()}
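Finally, a standalone illustration of why the sort key in the test helper changed (the sample values are made up). With the previous `(value is None, value)` key, a float NaN takes part in `<` comparisons, and since NaN compares false against everything, the resulting order is unpredictable. The new key pushes NaN to the end alongside None, so it is never compared against a real value:

    import math

    sort_list = [2.0, float("nan"), 1.0]
    sorted_indices = sorted(
        range(len(sort_list)),
        key=lambda i: (
            (sort_list[i] is None)
            or (isinstance(sort_list[i], float) and math.isnan(sort_list[i])),
            sort_list[i],
        ),
    )
    # sorted_indices == [2, 0, 1]: real values first, in ascending order,
    # then the NaN row, exactly like None under the old key.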