diff --git a/ibis/backends/sql/compilers/risingwave.py b/ibis/backends/sql/compilers/risingwave.py index c4baf94d6723..381649af3347 100644 --- a/ibis/backends/sql/compilers/risingwave.py +++ b/ibis/backends/sql/compilers/risingwave.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sqlglot as sg import sqlglot.expressions as sge import ibis.common.exceptions as com @@ -40,25 +41,23 @@ def visit_DateNow(self, op): return self.cast(sge.CurrentTimestamp(), dt.date) def visit_First(self, op, *, arg, where, order_by, include_null): - if include_null: - raise com.UnsupportedOperationError( - "`include_null=True` is not supported by the risingwave backend" - ) if not order_by: raise com.UnsupportedOperationError( "RisingWave requires an `order_by` be specified in `first`" ) + if not include_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.first_value(arg, where=where, order_by=order_by) def visit_Last(self, op, *, arg, where, order_by, include_null): - if include_null: - raise com.UnsupportedOperationError( - "`include_null=True` is not supported by the risingwave backend" - ) if not order_by: raise com.UnsupportedOperationError( "RisingWave requires an `order_by` be specified in `last`" ) + if not include_null: + cond = arg.is_(sg.not_(NULL, copy=False)) + where = cond if where is None else sge.And(this=cond, expression=where) return self.agg.last_value(arg, where=where, order_by=order_by) def visit_Correlation(self, op, *, left, right, how, where): diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index b7fd1fe72492..3b86de2b1eb6 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -647,23 +647,16 @@ def test_first_last(alltypes, method, filtered, include_null): raises=com.OperationNotDefinedError, ) @pytest.mark.parametrize("method", ["first", "last"]) -@pytest.mark.parametrize("filtered", [False, True]) +@pytest.mark.parametrize("filtered", [False, True], ids=["not-filtered", "filtered"]) @pytest.mark.parametrize( "include_null", [ - False, + param(False, id="exclude-null"), param( True, marks=[ pytest.mark.notimpl( - [ - "clickhouse", - "exasol", - "flink", - "postgres", - "risingwave", - "snowflake", - ], + ["clickhouse", "exasol", "flink", "postgres", "snowflake"], raises=com.UnsupportedOperationError, reason="`include_null=True` is not supported", ), @@ -674,6 +667,7 @@ def test_first_last(alltypes, method, filtered, include_null): strict=False, ), ], + id="include-null", ), ], )