Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: support array_compact function #1321

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
19 changes: 19 additions & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,25 @@ impl PhysicalPlanner {
));
Ok(array_has_any_expr)
}
// array_compact(arr): removes all NULL elements from the source array.
// Implemented without a dedicated kernel by delegating to the existing
// `array_remove_all` UDF with a typed NULL literal as the element to remove.
ExprStruct::ArrayCompact(expr) => {
// Physical expression producing the source array.
let src_array_expr =
self.create_expr(expr.array_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
// Arrow element type carried in the protobuf message; needed so the NULL
// literal below has the same type as the array's elements.
let datatype = to_arrow_datatype(expr.item_datatype.as_ref().unwrap());

// NOTE(review): presumably `ScalarValue::Null.cast_to(&datatype)` yields a typed
// null for every element type we serialize — confirm for nested/decimal types.
let null_literal_expr: Arc<dyn PhysicalExpr> =
Arc::new(Literal::new(ScalarValue::Null.cast_to(&datatype)?));
let args = vec![Arc::clone(&src_array_expr), null_literal_expr];
// Removing elements does not change the array's type, so reuse the input's type.
let return_type = src_array_expr.data_type(&input_schema)?;

let array_compact_expr = Arc::new(ScalarFunctionExpr::new(
"array_compact",
array_remove_all_udf(),
args,
return_type,
));

Ok(array_compact_expr)
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
Expand Down
6 changes: 6 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ message Expr {
BinaryExpr array_intersect = 62;
ArrayJoin array_join = 63;
BinaryExpr arrays_overlap = 64;
ArrayCompact array_compact = 65;
}
}

Expand Down Expand Up @@ -423,6 +424,11 @@ message ArrayJoin {
Expr null_replacement_expr = 3;
}

// Serialized form of Spark's array_compact: removes NULL entries from the array.
message ArrayCompact {
// Expression producing the source array.
Expr array_expr = 1;
// Element type of the array; the native side uses it to build a typed NULL
// literal that is passed to array_remove_all.
DataType item_datatype = 2;
}

message DataType {
enum DataTypeId {
BOOL = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2366,6 +2366,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
case _: ArrayIntersect => convert(CometArrayIntersect)
case _: ArrayJoin => convert(CometArrayJoin)
case _: ArraysOverlap => convert(CometArraysOverlap)
case expr @ ArrayFilter(child, _) if ArrayCompact(child).replacement.sql == expr.sql =>
kazantsev-maksim marked this conversation as resolved.
Show resolved Hide resolved
convert(CometArrayCompact)
case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
Expand Down
27 changes: 26 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, ArrayRemove, Attrib
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, StructType}

import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProto}
import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProto, serializeDataType}
import org.apache.comet.shims.CometExprShim

object CometArrayRemove extends CometExpressionSerde with CometExprShim {
Expand Down Expand Up @@ -126,6 +126,31 @@ object CometArraysOverlap extends CometExpressionSerde with IncompatExpr {
}
}

object CometArrayCompact extends CometExpressionSerde with IncompatExpr {

  /**
   * Serializes Spark's `array_compact` (matched upstream via its `ArrayFilter` replacement)
   * into the Comet native `ArrayCompact` protobuf expression.
   *
   * Returns `None` — tagging `expr` with the reason via `withInfo` — when either the array's
   * element type or the child expression cannot be serialized.
   */
  override def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val child = expr.children.head
    // Pattern match instead of `asInstanceOf[ArrayType]`: if the child's type is ever not an
    // array we fall through to the unsupported path instead of throwing ClassCastException.
    val elementType = child.dataType match {
      case ArrayType(elemType, _) => serializeDataType(elemType)
      case _ => None
    }
    val srcExprProto = exprToProto(child, inputs, binding)
    (elementType, srcExprProto) match {
      case (Some(itemType), Some(srcExpr)) =>
        val arrayCompactBuilder = ExprOuterClass.ArrayCompact
          .newBuilder()
          .setArrayExpr(srcExpr)
          .setItemDatatype(itemType)
        Some(
          ExprOuterClass.Expr
            .newBuilder()
            .setArrayCompact(arrayCompactBuilder)
            .build())
      case _ =>
        withInfo(expr, "unsupported arguments for ArrayCompact", expr.children: _*)
        None
    }
  }
}

object CometArrayJoin extends CometExpressionSerde with IncompatExpr {
override def convert(
expr: Expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,4 +292,24 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
}
}

test("array_compact") {
assume(isSpark34Plus)
withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, n = 10000)
spark.read.parquet(path.toString).createOrReplaceTempView("t1")

checkSparkAnswerAndOperator(
sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NULL"))
checkSparkAnswerAndOperator(
sql("SELECT array_compact(array(_2)) FROM t1 WHERE _2 IS NOT NULL"))
checkSparkAnswerAndOperator(
sql("SELECT array_compact(array(_2, _3, null)) FROM t1 WHERE _2 IS NOT NULL"))
}
}
}
}

}
Loading