diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 71ec5f334..ed727d021 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -834,6 +834,25 @@ impl PhysicalPlanner {
                 ));
                 Ok(array_has_any_expr)
             }
+            ExprStruct::ArrayCompact(expr) => {
+                let src_array_expr =
+                    self.create_expr(expr.array_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let datatype = to_arrow_datatype(expr.item_datatype.as_ref().unwrap());
+
+                let null_literal_expr: Arc<dyn PhysicalExpr> =
+                    Arc::new(Literal::new(ScalarValue::Null.cast_to(&datatype)?));
+                let args = vec![Arc::clone(&src_array_expr), null_literal_expr];
+                let return_type = src_array_expr.data_type(&input_schema)?;
+
+                let array_compact_expr = Arc::new(ScalarFunctionExpr::new(
+                    "array_compact",
+                    array_remove_all_udf(),
+                    args,
+                    return_type,
+                ));
+
+                Ok(array_compact_expr)
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index fd928fd8a..1a10a9f01 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -89,6 +89,7 @@ message Expr {
     BinaryExpr array_intersect = 62;
     ArrayJoin array_join = 63;
     BinaryExpr arrays_overlap = 64;
+    ArrayCompact array_compact = 65;
   }
 }
 
@@ -423,6 +424,11 @@ message ArrayJoin {
   Expr null_replacement_expr = 3;
 }
 
+message ArrayCompact {
+  Expr array_expr = 1;
+  DataType item_datatype = 2;
+}
+
 message DataType {
   enum DataTypeId {
     BOOL = 0;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index f4699af8d..808e7a2a2 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2366,6 +2366,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       case _: ArrayIntersect => convert(CometArrayIntersect)
       case _: ArrayJoin => convert(CometArrayJoin)
       case _: ArraysOverlap => convert(CometArraysOverlap)
+      case expr @ ArrayFilter(child, _) if ArrayCompact(child).replacement.sql == expr.sql =>
+        convert(CometArrayCompact)
       case _ =>
         withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
         None
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index db1679f22..5108ea849 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, ArrayRemove, Attrib
 import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, StructType}
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProto}
+import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProto, serializeDataType}
 import org.apache.comet.shims.CometExprShim
 
 object CometArrayRemove extends CometExpressionSerde with CometExprShim {
@@ -126,6 +126,31 @@
   }
 }
 
+object CometArrayCompact extends CometExpressionSerde with IncompatExpr {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val child = expr.children.head
+    val elementType = serializeDataType(child.dataType.asInstanceOf[ArrayType].elementType)
+    val srcExprProto = exprToProto(child, inputs, binding)
+    if (elementType.isDefined && srcExprProto.isDefined) {
+      val arrayCompactBuilder = ExprOuterClass.ArrayCompact
+        .newBuilder()
+        .setArrayExpr(srcExprProto.get)
+        .setItemDatatype(elementType.get)
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setArrayCompact(arrayCompactBuilder)
+          .build())
+    } else {
+      withInfo(expr, "unsupported arguments for ArrayCompact", expr.children: _*)
+      None
+    }
+  }
+}
+
 object CometArrayJoin extends CometExpressionSerde with IncompatExpr {
   override def convert(
       expr: Expression,
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index df1fccb69..8850f2133 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -292,4 +292,24 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     }
   }
 
+  test("array_compact") {
+    assume(isSpark34Plus)
+    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+      Seq(true, false).foreach { dictionaryEnabled =>
+        withTempDir { dir =>
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, n = 10000)
+          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+          checkSparkAnswerAndOperator(
+            sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NULL"))
+          checkSparkAnswerAndOperator(
+            sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NOT NULL"))
+          checkSparkAnswerAndOperator(
+            sql("SELECT array_compact(array(_2, _3, null)) FROM t1 WHERE _2 IS NOT NULL"))
+        }
+      }
+    }
+  }
+
 }