feat: spark like .list namespace and lazyframe .explode (#1887)
* feat: spark like list namespace and explode

* some cleanup and use F.explode_outer
FBruzzesi authored Jan 30, 2025
1 parent f644931 commit 83c1b2a
Showing 8 changed files with 102 additions and 27 deletions.
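For context, here is a rough sketch of what this commit enables through the public narwhals API (a hypothetical usage example, not part of the diff; assumes a local SparkSession and a narwhals version containing this change):

```python
import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
native = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ["a"])

lf = nw.from_native(native)

# The new `.list` namespace: per-row element counts for a List column.
lengths = lf.with_columns(a_len=nw.col("a").list.len())

# The new LazyFrame.explode: one output row per list element.
exploded = lf.explode("a")

lengths.to_native().show()
exploded.to_native().show()
```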
6 changes: 4 additions & 2 deletions narwhals/_duckdb/expr_list.py
@@ -5,14 +5,16 @@
 from duckdb import FunctionExpression
 
 if TYPE_CHECKING:
+    from typing_extensions import Self
+
     from narwhals._duckdb.expr import DuckDBExpr
 
 
 class DuckDBExprListNamespace:
-    def __init__(self, expr: DuckDBExpr) -> None:
+    def __init__(self: Self, expr: DuckDBExpr) -> None:
         self._compliant_expr = expr
 
-    def len(self) -> DuckDBExpr:
+    def len(self: Self) -> DuckDBExpr:
         return self._compliant_expr._from_call(
             lambda _input: FunctionExpression("len", _input),
             "len",
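The only substantive change in this file is the `self: Self` annotation, with `Self` imported from `typing_extensions` under `TYPE_CHECKING`. A minimal standalone illustration of the pattern (hypothetical class, unrelated to the narwhals code above):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing_extensions import Self


class Counter:
    def __init__(self: Self, n: int) -> None:
        self.n = n

    def bumped(self: Self) -> Self:
        # Annotating the return type as `Self` (rather than `Counter`)
        # keeps the type precise for subclasses too.
        return type(self)(self.n + 1)
```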
48 changes: 45 additions & 3 deletions narwhals/_spark_like/dataframe.py
@@ -9,9 +9,11 @@
 from narwhals._spark_like.utils import ExprKind
 from narwhals._spark_like.utils import native_to_narwhals_dtype
 from narwhals._spark_like.utils import parse_exprs_and_named_exprs
+from narwhals.exceptions import InvalidOperationError
 from narwhals.typing import CompliantLazyFrame
 from narwhals.utils import Implementation
 from narwhals.utils import check_column_exists
+from narwhals.utils import import_dtypes_module
 from narwhals.utils import parse_columns_to_drop
 from narwhals.utils import parse_version
 from narwhals.utils import validate_backend_version
@@ -46,7 +48,7 @@ def __init__(
         validate_backend_version(self._implementation, self._backend_version)
 
     @property
-    def _F(self) -> Any:  # noqa: N802
+    def _F(self: Self) -> Any:  # noqa: N802
         if self._implementation is Implementation.SQLFRAME:
             from sqlframe.duckdb import functions
 
@@ -56,7 +58,7 @@ def _F(self) -> Any:  # noqa: N802
         return functions
 
     @property
-    def _native_dtypes(self) -> Any:
+    def _native_dtypes(self: Self) -> Any:
         if self._implementation is Implementation.SQLFRAME:
             from sqlframe.duckdb import types
 
@@ -66,7 +68,7 @@ def _native_dtypes(self) -> Any:
         return types
 
     @property
-    def _Window(self) -> Any:  # noqa: N802
+    def _Window(self: Self) -> Any:  # noqa: N802
         if self._implementation is Implementation.SQLFRAME:
             from sqlframe.duckdb import Window
 
@@ -312,3 +314,43 @@ def join(
         return self._from_native_frame(
             self_native.join(other, on=left_on, how=how).select(col_order)
         )
+
+    def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Self:
+        dtypes = import_dtypes_module(self._version)
+
+        to_explode = (
+            [columns, *more_columns]
+            if isinstance(columns, str)
+            else [*columns, *more_columns]
+        )
+        schema = self.collect_schema()
+        for col_to_explode in to_explode:
+            dtype = schema[col_to_explode]
+
+            if dtype != dtypes.List:
+                msg = (
+                    f"`explode` operation not supported for dtype `{dtype}`, "
+                    "expected List type"
+                )
+                raise InvalidOperationError(msg)
+
+        native_frame = self._native_frame
+        column_names = self.columns
+
+        if len(to_explode) != 1:
+            msg = (
+                "Exploding on multiple columns is not supported with SparkLike backend since "
+                "we cannot guarantee that the exploded columns have matching element counts."
+            )
+            raise NotImplementedError(msg)
+
+        return self._from_native_frame(
+            native_frame.select(
+                *[
+                    self._F.col(col_name).alias(col_name)
+                    if col_name != to_explode[0]
+                    else self._F.explode_outer(col_name).alias(col_name)
+                    for col_name in column_names
+                ]
+            )
+        )
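Note the use of `explode_outer` rather than `explode` in the final `select`: empty lists and null lists yield a row with a null value instead of being dropped, matching Polars-style explode semantics. A plain-PySpark sketch of the difference (hypothetical data):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, [1, 2]), (2, []), (3, None)], ["id", "vals"])

# explode() would silently drop rows 2 and 3; explode_outer() keeps them,
# emitting one row each with vals = NULL.
df.select("id", F.explode_outer("vals").alias("vals")).show()
```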
5 changes: 5 additions & 0 deletions narwhals/_spark_like/expr.py
@@ -7,6 +7,7 @@
 from typing import Sequence
 
 from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
+from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
 from narwhals._spark_like.expr_name import SparkLikeExprNameNamespace
 from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
 from narwhals._spark_like.utils import ExprKind
@@ -556,3 +557,7 @@ def name(self: Self) -> SparkLikeExprNameNamespace:
     @property
     def dt(self: Self) -> SparkLikeExprDateTimeNamespace:
         return SparkLikeExprDateTimeNamespace(self)
+
+    @property
+    def list(self: Self) -> SparkLikeExprListNamespace:
+        return SparkLikeExprListNamespace(self)
20 changes: 20 additions & 0 deletions narwhals/_spark_like/expr_list.py
@@ -0,0 +1,20 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from typing_extensions import Self
+
+    from narwhals._spark_like.expr import SparkLikeExpr
+
+
+class SparkLikeExprListNamespace:
+    def __init__(self: Self, expr: SparkLikeExpr) -> None:
+        self._compliant_expr = expr
+
+    def len(self: Self) -> SparkLikeExpr:
+        return self._compliant_expr._from_call(
+            self._compliant_expr._F.array_size,
+            "len",
+            expr_kind=self._compliant_expr._expr_kind,
+        )
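Here `len` delegates to Spark's `array_size` function (added in PySpark 3.5), which returns the element count and propagates nulls. Roughly equivalent raw PySpark, with hypothetical data:

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([([1, 2, 3],), ([],), (None,)], ["a"])

# array_size -> 3, 0, NULL: the same values `.list.len()` surfaces.
df.select(F.array_size("a").alias("len")).show()
```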
29 changes: 16 additions & 13 deletions narwhals/_spark_like/utils.py
@@ -57,26 +57,25 @@ def native_to_narwhals_dtype(
         return dtypes.Int16()
     if isinstance(dtype, spark_types.ByteType):
         return dtypes.Int8()
-    string_types = [
-        spark_types.StringType,
-        spark_types.VarcharType,
-        spark_types.CharType,
-    ]
-    if any(isinstance(dtype, t) for t in string_types):
+    if isinstance(
+        dtype, (spark_types.StringType, spark_types.VarcharType, spark_types.CharType)
+    ):
         return dtypes.String()
     if isinstance(dtype, spark_types.BooleanType):
         return dtypes.Boolean()
     if isinstance(dtype, spark_types.DateType):
         return dtypes.Date()
-    datetime_types = [
-        spark_types.TimestampType,
-        spark_types.TimestampNTZType,
-    ]
-    if any(isinstance(dtype, t) for t in datetime_types):
+    if isinstance(dtype, (spark_types.TimestampType, spark_types.TimestampNTZType)):
         return dtypes.Datetime()
     if isinstance(dtype, spark_types.DecimalType):  # pragma: no cover
         # TODO(unassigned): cover this in dtypes_test.py
         return dtypes.Decimal()
+    if isinstance(dtype, spark_types.ArrayType):  # pragma: no cover
+        return dtypes.List(
+            inner=native_to_narwhals_dtype(
+                dtype.elementType, version=version, spark_types=spark_types
+            )
+        )
     return dtypes.Unknown()

@@ -105,8 +104,12 @@ def narwhals_to_native_dtype(
         msg = "Converting to Date or Datetime dtype is not supported yet"
         raise NotImplementedError(msg)
     if isinstance_or_issubclass(dtype, dtypes.List):  # pragma: no cover
-        msg = "Converting to List dtype is not supported yet"
-        raise NotImplementedError(msg)
+        inner = narwhals_to_native_dtype(
+            dtype.inner,  # type: ignore[union-attr]
+            version=version,
+            spark_types=spark_types,
+        )
+        return spark_types.ArrayType(elementType=inner)
     if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
         msg = "Converting to Struct dtype is not supported yet"
         raise NotImplementedError(msg)
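With the `ArrayType` branch added in both directions, `List` dtypes now survive the dtype mapping instead of collapsing to `Unknown` or raising. A quick sketch of the visible effect (hypothetical session):

```python
import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([([1, 2],)], ["a"])

# ArrayType(LongType()) now maps to List(Int64) rather than Unknown.
print(nw.from_native(df).collect_schema()["a"])  # List(Int64)
```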
4 changes: 1 addition & 3 deletions tests/expr_and_series/list/len_test.py
@@ -17,9 +17,7 @@ def test_len_expr(
     request: pytest.FixtureRequest,
     constructor: Constructor,
 ) -> None:
-    if any(
-        backend in str(constructor) for backend in ("dask", "modin", "cudf", "pyspark")
-    ):
+    if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")):
         request.applymarker(pytest.mark.xfail)
 
     if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
10 changes: 5 additions & 5 deletions tests/frame/explode_test.py
@@ -40,7 +40,7 @@ def test_explode_single_col(
 ) -> None:
     if any(
         backend in str(constructor)
-        for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark")
+        for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb")
     ):
         request.applymarker(pytest.mark.xfail)
 
@@ -110,16 +110,16 @@ def test_explode_shape_error(
 ) -> None:
     if any(
         backend in str(constructor)
-        for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark")
+        for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb")
     ):
         request.applymarker(pytest.mark.xfail)
 
     if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
         request.applymarker(pytest.mark.xfail)
 
     with pytest.raises(
-        (ShapeError, PlShapeError),
-        match="exploded columns must have matching element counts",
+        (ShapeError, PlShapeError, NotImplementedError),
+        match=r".*exploded columns (must )?have matching element counts",
     ):
         _ = (
             nw.from_native(constructor(data))
@@ -133,7 +133,7 @@ def test_explode_invalid_operation_error(
 def test_explode_invalid_operation_error(
     request: pytest.FixtureRequest, constructor: Constructor
 ) -> None:
-    if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb", "pyspark")):
+    if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb")):
         request.applymarker(pytest.mark.xfail)
 
     if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6):
7 changes: 6 additions & 1 deletion tests/utils.py
@@ -66,7 +66,12 @@ def _sort_dict_by_key(
 ) -> dict[str, list[Any]]:  # pragma: no cover
     sort_list = data_dict[key]
     sorted_indices = sorted(
-        range(len(sort_list)), key=lambda i: (sort_list[i] is None, sort_list[i])
+        range(len(sort_list)),
+        key=lambda i: (
+            (sort_list[i] is None)
+            or (isinstance(sort_list[i], float) and math.isnan(sort_list[i])),
+            sort_list[i],
+        ),
     )
     return {key: [value[i] for i in sorted_indices] for key, value in data_dict.items()}
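The reworked key makes `float("nan")` sort into the same trailing group as `None`, so NaN values no longer interleave with real numbers when test output is canonicalised. In isolation, with hypothetical values:

```python
import math

vals = [3.0, float("nan"), 1.0, 2.0]
order = sorted(
    range(len(vals)),
    key=lambda i: (
        # First key: True for None/NaN, pushing them after all real values.
        (vals[i] is None) or (isinstance(vals[i], float) and math.isnan(vals[i])),
        vals[i],
    ),
)
print([vals[i] for i in order])  # [1.0, 2.0, 3.0, nan]
```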
