From 3bdf0e8b024b4700f8b6adf443a033274369a1a5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 14 Feb 2025 13:11:15 +0000 Subject: [PATCH] refactor: replace `pa.scalar` w/ `lit` alias Aligns with the equivalent change for `duckdb` (https://github.com/narwhals-dev/narwhals/pull/1966/commits/5306ce2e31f2d547e9799bf3fa0fe4d2e0f1c12a) --- narwhals/_arrow/series.py | 37 ++++++++++++++++-------------------- narwhals/_arrow/series_dt.py | 10 +++++----- narwhals/_arrow/utils.py | 16 ++++++---------- 3 files changed, 27 insertions(+), 36 deletions(-) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 5ceddae3a..9759da182 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -22,6 +22,7 @@ from narwhals._arrow.utils import cast_for_truediv from narwhals._arrow.utils import chunked_array from narwhals._arrow.utils import floordiv_compat +from narwhals._arrow.utils import lit from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._arrow.utils import pad_series @@ -246,14 +247,14 @@ def __truediv__(self: Self, other: Any) -> ArrowSeries[_ScalarT_co]: ser, other = broadcast_and_extract_native(self, other, self._backend_version) if not isinstance(other, (pa.Array, pa.ChunkedArray)): # scalar - other = pa.scalar(other) + other = lit(other) return self._from_native_series(pc.divide(*cast_for_truediv(ser, other))) def __rtruediv__(self: Self, other: Any) -> ArrowSeries[_ScalarT_co]: ser, right = broadcast_and_extract_native(self, other, self._backend_version) if not isinstance(right, (pa.Array, pa.ChunkedArray)): # scalar - right = pa.scalar(right) if not isinstance(right, pa.Scalar) else right + right = lit(right) if not isinstance(right, pa.Scalar) else right return self._from_native_series(pc.divide(*cast_for_truediv(right, ser))) def __mod__(self: Self, other: Any) -> ArrowSeries[_ScalarT_co]: @@ -371,9 +372,9 @@ def skew(self: Self, *, _return_py_scalar: bool = True) -> float | None: m = cast( "pc.NumericArray[Any]", pc.subtract(ser_not_null, pc.mean(ser_not_null)) ) - m2 = pc.mean(pc.power(m, pa.scalar(2))) - m3 = pc.mean(pc.power(m, pa.scalar(3))) - biased_population_skewness = pc.divide(m3, pc.power(m2, pa.scalar(1.5))) + m2 = pc.mean(pc.power(m, lit(2))) + m3 = pc.mean(pc.power(m, lit(3))) + biased_population_skewness = pc.divide(m3, pc.power(m2, lit(1.5))) return maybe_extract_py_scalar(biased_population_skewness, _return_py_scalar) def count(self: Self, *, _return_py_scalar: bool = True) -> int: @@ -693,10 +694,7 @@ def fill_aux( )[::-1] distance = valid_index - indices return pc.if_else( - pc.and_( - pc.is_null(arr), - pc.less_equal(distance, pa.scalar(limit)), - ), + pc.and_(pc.is_null(arr), pc.less_equal(distance, lit(limit))), arr.take(valid_index), arr, ) @@ -705,7 +703,7 @@ def fill_aux( dtype = ser.type if value is not None: - res_ser = self._from_native_series(pc.fill_null(ser, pa.scalar(value, dtype))) # type: ignore[attr-defined] + res_ser = self._from_native_series(pc.fill_null(ser, lit(value, dtype))) # type: ignore[attr-defined] elif limit is None: fill_func = ( pc.fill_null_forward if strategy == "forward" else pc.fill_null_backward @@ -1112,7 +1110,7 @@ def rank( rank = pc.rank(native_series, sort_keys=sort_keys, tiebreaker=tiebreaker) - result = pc.if_else(null_mask, pa.scalar(None, native_series.type), rank) + result = pc.if_else(null_mask, lit(None, native_series.type), rank) return self._from_native_series(result) def hist( # noqa: PLR0915 @@ -1135,15 +1133,15 @@ def _hist_from_bin_count(bin_count: int): # type: ignore[no-untyped-def] # noqa pad_lowest_bin = False pa_float = pa.type_for_alias("float") if lower == upper: - range_ = pa.scalar(1.0) - mid = pa.scalar(0.5) - width = pc.divide(range_, pa.scalar(bin_count)) + range_ = lit(1.0) + mid = lit(0.5) + width = pc.divide(range_, lit(bin_count)) lower = pc.subtract(lower, mid) upper = pc.add(upper, mid) else: pad_lowest_bin = True range_ = pc.subtract(upper, lower) - width = pc.divide(pc.cast(range_, pa_float), pa.scalar(float(bin_count))) + width = pc.divide(pc.cast(range_, pa_float), lit(float(bin_count))) bin_proportions = pc.divide( pc.subtract( @@ -1185,7 +1183,7 @@ def _hist_from_bin_count(bin_count: int): # type: ignore[no-untyped-def] # noqa # empty bin intervals should have a 0 count counts_coalesce = cast( "pa.Array[Any]", - pc.coalesce(cast("pa.Array[Any]", counts.column("counts")), pa.scalar(0)), + pc.coalesce(cast("pa.Array[Any]", counts.column("counts")), lit(0)), ) counts = counts.set_column(0, "counts", counts_coalesce) @@ -1196,8 +1194,7 @@ def _hist_from_bin_count(bin_count: int): # type: ignore[no-untyped-def] # noqa # pad lowest bin by 1% of range lowest_padded = [ pc.subtract( - bin_left[0], - pc.multiply(pc.cast(range_, pa_float), pa.scalar(0.001)), + bin_left[0], pc.multiply(pc.cast(range_, pa_float), lit(0.001)) ) ] bin_left = chunked_array([lowest_padded, cast("Any", bin_left[1:])]) @@ -1257,9 +1254,7 @@ def __contains__(self: Self, other: Any) -> bool: try: native_series = self._native_series other_ = ( - pa.scalar(other) - if other is not None - else pa.scalar(None, type=native_series.type) + lit(other) if other is not None else lit(None, type=native_series.type) ) return maybe_extract_py_scalar( pc.is_in(other_, native_series), return_py_scalar=True diff --git a/narwhals/_arrow/series_dt.py b/narwhals/_arrow/series_dt.py index d8f463826..ef51fff90 100644 --- a/narwhals/_arrow/series_dt.py +++ b/narwhals/_arrow/series_dt.py @@ -190,7 +190,7 @@ def total_minutes(self: Self) -> ArrowSeries[pa.Int64Scalar]: "us": 60 * 1e6, # micro "ns": 60 * 1e9, # nano } - factor = pa.scalar(unit_to_minutes_factor[ser._type.unit], type=pa.int64()) + factor = lit(unit_to_minutes_factor[ser._type.unit], type=pa.int64()) return self._compliant_series._from_native_series( pc.cast(pc.divide(ser._native_series, factor), pa.int64()) ) @@ -203,7 +203,7 @@ def total_seconds(self: Self) -> ArrowSeries[pa.Int64Scalar]: "us": 1e6, # micro "ns": 1e9, # nano } - factor = pa.scalar(unit_to_seconds_factor[ser._type.unit], type=pa.int64()) + factor = lit(unit_to_seconds_factor[ser._type.unit], type=pa.int64()) return self._compliant_series._from_native_series( pc.cast(pc.divide(ser._native_series, factor), pa.int64()) ) @@ -218,7 +218,7 @@ def total_milliseconds(self: Self) -> ArrowSeries[pa.Int64Scalar]: "us": 1e3, # micro "ns": 1e6, # nano } - factor = pa.scalar(unit_to_milli_factor[unit], type=pa.int64()) + factor = lit(unit_to_milli_factor[unit], type=pa.int64()) if unit == "s": return self._compliant_series._from_native_series( pc.cast(pc.multiply(arr, factor), pa.int64()) @@ -237,7 +237,7 @@ def total_microseconds(self: Self) -> ArrowSeries[pa.Int64Scalar]: "us": 1, # micro "ns": 1e3, # nano } - factor = pa.scalar(unit_to_micro_factor[unit], type=pa.int64()) + factor = lit(unit_to_micro_factor[unit], type=pa.int64()) if unit in {"s", "ms"}: return self._compliant_series._from_native_series( pc.cast(pc.multiply(arr, factor), pa.int64()) @@ -254,7 +254,7 @@ def total_nanoseconds(self: Self) -> ArrowSeries[pa.Int64Scalar]: "us": 1e3, # micro "ns": 1, # nano } - factor = pa.scalar(unit_to_nano_factor[ser._type.unit], type=pa.int64()) + factor = lit(unit_to_nano_factor[ser._type.unit], type=pa.int64()) return self._compliant_series._from_native_series( pc.cast(pc.multiply(ser._native_series, factor), pa.int64()) ) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index c67581023..4431a8a09 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -233,7 +233,7 @@ def broadcast_and_extract_native( from narwhals._arrow.series import ArrowSeries if rhs is None: # DONE - return lhs._native_series, pa.scalar(None, type=lhs._native_series.type) + return lhs._native_series, lit(None, type=lhs._native_series.type) # If `rhs` is the output of an expression evaluation, then it is # a list of Series. So, we verify that that list is of length-1, @@ -352,10 +352,10 @@ def floordiv_compat(left: Any, right: Any) -> Any: # The following lines are adapted from pandas' pyarrow implementation. # Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154 if isinstance(left, (int, float)): - left = pa.scalar(left) + left = lit(left) if isinstance(right, (int, float)): - right = pa.scalar(right) + right = lit(right) if pa.types.is_integer(left.type) and pa.types.is_integer(right.type): divided = pc.divide_checked(left, right) @@ -363,16 +363,12 @@ def floordiv_compat(left: Any, right: Any) -> Any: # GH 56676 has_remainder = pc.not_equal(pc.multiply(divided, right), left) has_one_negative_operand = pc.less( - pc.bit_wise_xor(left, right), - pa.scalar(0, type=divided.type), + pc.bit_wise_xor(left, right), lit(0, type=divided.type) ) result = pc.if_else( - pc.and_( - has_remainder, - has_one_negative_operand, - ), + pc.and_(has_remainder, has_one_negative_operand), # GH: 55561 ruff: ignore - pc.subtract(divided, pa.scalar(1, type=divided.type)), + pc.subtract(divided, lit(1, type=divided.type)), divided, ) else: