Skip to content

Commit

Permalink
refactor: replace pa.scalar w/ lit alias
Browse files Browse the repository at this point in the history
Aligns with the equivalent change for `duckdb` (5306ce2)
  • Loading branch information
dangotbanned committed Feb 14, 2025
1 parent cf27f7d commit 3bdf0e8
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 36 deletions.
37 changes: 16 additions & 21 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from narwhals._arrow.utils import cast_for_truediv
from narwhals._arrow.utils import chunked_array
from narwhals._arrow.utils import floordiv_compat
from narwhals._arrow.utils import lit
from narwhals._arrow.utils import narwhals_to_native_dtype
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._arrow.utils import pad_series
Expand Down Expand Up @@ -246,14 +247,14 @@ def __truediv__(self: Self, other: Any) -> ArrowSeries[_ScalarT_co]:
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
if not isinstance(other, (pa.Array, pa.ChunkedArray)):
# scalar
other = pa.scalar(other)
other = lit(other)
return self._from_native_series(pc.divide(*cast_for_truediv(ser, other)))

def __rtruediv__(self: Self, other: Any) -> ArrowSeries[_ScalarT_co]:
ser, right = broadcast_and_extract_native(self, other, self._backend_version)
if not isinstance(right, (pa.Array, pa.ChunkedArray)):
# scalar
right = pa.scalar(right) if not isinstance(right, pa.Scalar) else right
right = lit(right) if not isinstance(right, pa.Scalar) else right
return self._from_native_series(pc.divide(*cast_for_truediv(right, ser)))

def __mod__(self: Self, other: Any) -> ArrowSeries[_ScalarT_co]:
Expand Down Expand Up @@ -371,9 +372,9 @@ def skew(self: Self, *, _return_py_scalar: bool = True) -> float | None:
m = cast(
"pc.NumericArray[Any]", pc.subtract(ser_not_null, pc.mean(ser_not_null))
)
m2 = pc.mean(pc.power(m, pa.scalar(2)))
m3 = pc.mean(pc.power(m, pa.scalar(3)))
biased_population_skewness = pc.divide(m3, pc.power(m2, pa.scalar(1.5)))
m2 = pc.mean(pc.power(m, lit(2)))
m3 = pc.mean(pc.power(m, lit(3)))
biased_population_skewness = pc.divide(m3, pc.power(m2, lit(1.5)))
return maybe_extract_py_scalar(biased_population_skewness, _return_py_scalar)

def count(self: Self, *, _return_py_scalar: bool = True) -> int:
Expand Down Expand Up @@ -693,10 +694,7 @@ def fill_aux(
)[::-1]
distance = valid_index - indices
return pc.if_else(
pc.and_(
pc.is_null(arr),
pc.less_equal(distance, pa.scalar(limit)),
),
pc.and_(pc.is_null(arr), pc.less_equal(distance, lit(limit))),
arr.take(valid_index),
arr,
)
Expand All @@ -705,7 +703,7 @@ def fill_aux(
dtype = ser.type

if value is not None:
res_ser = self._from_native_series(pc.fill_null(ser, pa.scalar(value, dtype))) # type: ignore[attr-defined]
res_ser = self._from_native_series(pc.fill_null(ser, lit(value, dtype))) # type: ignore[attr-defined]
elif limit is None:
fill_func = (
pc.fill_null_forward if strategy == "forward" else pc.fill_null_backward
Expand Down Expand Up @@ -1112,7 +1110,7 @@ def rank(

rank = pc.rank(native_series, sort_keys=sort_keys, tiebreaker=tiebreaker)

result = pc.if_else(null_mask, pa.scalar(None, native_series.type), rank)
result = pc.if_else(null_mask, lit(None, native_series.type), rank)
return self._from_native_series(result)

def hist( # noqa: PLR0915
Expand All @@ -1135,15 +1133,15 @@ def _hist_from_bin_count(bin_count: int): # type: ignore[no-untyped-def] # noqa
pad_lowest_bin = False
pa_float = pa.type_for_alias("float")
if lower == upper:
range_ = pa.scalar(1.0)
mid = pa.scalar(0.5)
width = pc.divide(range_, pa.scalar(bin_count))
range_ = lit(1.0)
mid = lit(0.5)
width = pc.divide(range_, lit(bin_count))
lower = pc.subtract(lower, mid)
upper = pc.add(upper, mid)
else:
pad_lowest_bin = True
range_ = pc.subtract(upper, lower)
width = pc.divide(pc.cast(range_, pa_float), pa.scalar(float(bin_count)))
width = pc.divide(pc.cast(range_, pa_float), lit(float(bin_count)))

bin_proportions = pc.divide(
pc.subtract(
Expand Down Expand Up @@ -1185,7 +1183,7 @@ def _hist_from_bin_count(bin_count: int): # type: ignore[no-untyped-def] # noqa
# empty bin intervals should have a 0 count
counts_coalesce = cast(
"pa.Array[Any]",
pc.coalesce(cast("pa.Array[Any]", counts.column("counts")), pa.scalar(0)),
pc.coalesce(cast("pa.Array[Any]", counts.column("counts")), lit(0)),
)
counts = counts.set_column(0, "counts", counts_coalesce)

Expand All @@ -1196,8 +1194,7 @@ def _hist_from_bin_count(bin_count: int): # type: ignore[no-untyped-def] # noqa
# pad lowest bin by 1% of range
lowest_padded = [
pc.subtract(
bin_left[0],
pc.multiply(pc.cast(range_, pa_float), pa.scalar(0.001)),
bin_left[0], pc.multiply(pc.cast(range_, pa_float), lit(0.001))
)
]
bin_left = chunked_array([lowest_padded, cast("Any", bin_left[1:])])
Expand Down Expand Up @@ -1257,9 +1254,7 @@ def __contains__(self: Self, other: Any) -> bool:
try:
native_series = self._native_series
other_ = (
pa.scalar(other)
if other is not None
else pa.scalar(None, type=native_series.type)
lit(other) if other is not None else lit(None, type=native_series.type)
)
return maybe_extract_py_scalar(
pc.is_in(other_, native_series), return_py_scalar=True
Expand Down
10 changes: 5 additions & 5 deletions narwhals/_arrow/series_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def total_minutes(self: Self) -> ArrowSeries[pa.Int64Scalar]:
"us": 60 * 1e6, # micro
"ns": 60 * 1e9, # nano
}
factor = pa.scalar(unit_to_minutes_factor[ser._type.unit], type=pa.int64())
factor = lit(unit_to_minutes_factor[ser._type.unit], type=pa.int64())
return self._compliant_series._from_native_series(
pc.cast(pc.divide(ser._native_series, factor), pa.int64())
)
Expand All @@ -203,7 +203,7 @@ def total_seconds(self: Self) -> ArrowSeries[pa.Int64Scalar]:
"us": 1e6, # micro
"ns": 1e9, # nano
}
factor = pa.scalar(unit_to_seconds_factor[ser._type.unit], type=pa.int64())
factor = lit(unit_to_seconds_factor[ser._type.unit], type=pa.int64())
return self._compliant_series._from_native_series(
pc.cast(pc.divide(ser._native_series, factor), pa.int64())
)
Expand All @@ -218,7 +218,7 @@ def total_milliseconds(self: Self) -> ArrowSeries[pa.Int64Scalar]:
"us": 1e3, # micro
"ns": 1e6, # nano
}
factor = pa.scalar(unit_to_milli_factor[unit], type=pa.int64())
factor = lit(unit_to_milli_factor[unit], type=pa.int64())
if unit == "s":
return self._compliant_series._from_native_series(
pc.cast(pc.multiply(arr, factor), pa.int64())
Expand All @@ -237,7 +237,7 @@ def total_microseconds(self: Self) -> ArrowSeries[pa.Int64Scalar]:
"us": 1, # micro
"ns": 1e3, # nano
}
factor = pa.scalar(unit_to_micro_factor[unit], type=pa.int64())
factor = lit(unit_to_micro_factor[unit], type=pa.int64())
if unit in {"s", "ms"}:
return self._compliant_series._from_native_series(
pc.cast(pc.multiply(arr, factor), pa.int64())
Expand All @@ -254,7 +254,7 @@ def total_nanoseconds(self: Self) -> ArrowSeries[pa.Int64Scalar]:
"us": 1e3, # micro
"ns": 1, # nano
}
factor = pa.scalar(unit_to_nano_factor[ser._type.unit], type=pa.int64())
factor = lit(unit_to_nano_factor[ser._type.unit], type=pa.int64())
return self._compliant_series._from_native_series(
pc.cast(pc.multiply(ser._native_series, factor), pa.int64())
)
16 changes: 6 additions & 10 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def broadcast_and_extract_native(
from narwhals._arrow.series import ArrowSeries

if rhs is None: # DONE
return lhs._native_series, pa.scalar(None, type=lhs._native_series.type)
return lhs._native_series, lit(None, type=lhs._native_series.type)

# If `rhs` is the output of an expression evaluation, then it is
# a list of Series. So, we verify that that list is of length-1,
Expand Down Expand Up @@ -352,27 +352,23 @@ def floordiv_compat(left: Any, right: Any) -> Any:
# The following lines are adapted from pandas' pyarrow implementation.
# Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154
if isinstance(left, (int, float)):
left = pa.scalar(left)
left = lit(left)

if isinstance(right, (int, float)):
right = pa.scalar(right)
right = lit(right)

if pa.types.is_integer(left.type) and pa.types.is_integer(right.type):
divided = pc.divide_checked(left, right)
if pa.types.is_signed_integer(divided.type):
# GH 56676
has_remainder = pc.not_equal(pc.multiply(divided, right), left)
has_one_negative_operand = pc.less(
pc.bit_wise_xor(left, right),
pa.scalar(0, type=divided.type),
pc.bit_wise_xor(left, right), lit(0, type=divided.type)
)
result = pc.if_else(
pc.and_(
has_remainder,
has_one_negative_operand,
),
pc.and_(has_remainder, has_one_negative_operand),
# GH: 55561 ruff: ignore
pc.subtract(divided, pa.scalar(1, type=divided.type)),
pc.subtract(divided, lit(1, type=divided.type)),
divided,
)
else:
Expand Down

0 comments on commit 3bdf0e8

Please sign in to comment.