diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala
index 596626698..e10b2e2a6 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLEvalITSuite.scala
@@ -10,7 +10,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, LessThan, Literal, SortOrder}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, LogicalPlan, Project, Sort}
 import org.apache.spark.sql.streaming.StreamTest
 
 class FlintSparkPPLEvalITSuite
@@ -22,6 +22,7 @@ class FlintSparkPPLEvalITSuite
   /** Test table and index name */
   private val testTable = "spark_catalog.default.flint_ppl_test"
   private val testTableHttpLog = "spark_catalog.default.flint_ppl_test_http_log"
+  private val duplicatesNullableTestTable = "spark_catalog.default.duplicates_nullable_test"
 
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -29,6 +30,7 @@ class FlintSparkPPLEvalITSuite
     // Create test table
     createPartitionedStateCountryTable(testTable)
     createTableHttpLog(testTableHttpLog)
+    createDuplicationNullableTable(duplicatesNullableTestTable)
   }
 
   protected override def afterEach(): Unit = {
@@ -632,8 +634,45 @@ class FlintSparkPPLEvalITSuite
     EqualTo(Literal(true), and)
   }
 
-  // Todo excluded fields not support yet
+  test("Test eval and signum function") {
+    val frame = sql(s"""
+         | source = $duplicatesNullableTestTable | fields id | sort id | eval i = pow(-2, id), s = signum(i) | head 5
+         | """.stripMargin)
+    val rows = frame.collect()
+    val expectedResults: Array[Row] = Array(
+      Row(1, -2d, -1d),
+      Row(2, 4d, 1d),
+      Row(3, -8d, -1d),
+      Row(4, 16d, 1d),
+      Row(5, -32d, -1d))
+    assert(rows.sameElements(expectedResults))
+
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    val tablePlan =
+      UnresolvedRelation(Seq("spark_catalog", "default", "duplicates_nullable_test"))
+    val projectIdPlan = Project(Seq(UnresolvedAttribute("id")), tablePlan)
+    val sortPlan =
+      Sort(Seq(SortOrder(UnresolvedAttribute("id"), Ascending)), global = true, projectIdPlan)
+    val evalPlan = Project(
+      Seq(
+        UnresolvedStar(None),
+        Alias(
+          UnresolvedFunction(
+            "pow",
+            Seq(Literal(-2), UnresolvedAttribute("id")),
+            isDistinct = false),
+          "i")(),
+        Alias(
+          UnresolvedFunction("signum", Seq(UnresolvedAttribute("i")), isDistinct = false),
+          "s")()),
+      sortPlan)
+    val localLimitPlan = LocalLimit(Literal(5), evalPlan)
+    val globalLimitPlan = GlobalLimit(Literal(5), localLimitPlan)
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), globalLimitPlan)
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+  // Todo excluded fields not support yet
   ignore("test single eval expression with excluded fields") {
     val frame = sql(s"""
         | source = $testTable | eval new_field = "New Field" | fields - age
diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md
index f07fcbd3f..6b3996f52 100644
--- a/ppl-spark-integration/README.md
+++ b/ppl-spark-integration/README.md
@@ -333,6 +333,7 @@ Limitation: Overriding existing field is unsupported, following queries throw ex
   - `source = table | eval a = 10 | fields a,b,c`
   - `source = table | eval a = a * 2 | stats avg(a)`
   - `source = table | eval a = abs(a) | where a > 0`
+  - `source = table | eval a = signum(a) | where a < 0`
 
 **Aggregations**
   - `source = table | stats avg(a) `
diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
index 60e1e9922..7af3e2109 100644
--- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
+++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
@@ -259,6 +259,7 @@ POWER: 'POWER';
 RAND: 'RAND';
 ROUND: 'ROUND';
 SIGN: 'SIGN';
+SIGNUM: 'SIGNUM';
 SQRT: 'SQRT';
 TRUNCATE: 'TRUNCATE';
 
diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
index 626ff2165..385c871cb 100644
--- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
+++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
@@ -568,6 +568,7 @@ mathematicalFunctionName
    | RAND
    | ROUND
    | SIGN
+   | SIGNUM
    | SQRT
    | TRUNCATE
    | trigonometricFunctionName
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala
index ed72a3d40..feaa7d8ca 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanMathFunctionsTranslatorTestSuite.scala
@@ -191,4 +191,18 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite
     val expectedPlan = Project(projectList, evalProject)
     comparePlans(expectedPlan, logPlan, false)
   }
+
+  test("test signum") {
+    val context = new CatalystPlanContext
+    val logPlan = planTransformer.visit(plan(pplParser, "source=t a = signum(b)"), context)
+
+    val table = UnresolvedRelation(Seq("t"))
+    val filterExpr = EqualTo(
+      UnresolvedAttribute("a"),
+      UnresolvedFunction("signum", seq(UnresolvedAttribute("b")), isDistinct = false))
+    val filterPlan = Filter(filterExpr, table)
+    val projectList = Seq(UnresolvedStar(None))
+    val expectedPlan = Project(projectList, filterPlan)
+    comparePlans(expectedPlan, logPlan, false)
+  }
 }
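
For reference, PPL `signum` is translated to an `UnresolvedFunction("signum", ...)`, which Spark resolves to its built-in `signum` expression, so the runtime semantics are Spark SQL's: -1.0, 0.0, or 1.0 returned as a Double. Below is a minimal standalone sketch of that behavior; it is not part of this change, and the object name and local SparkSession setup are illustrative only:

import org.apache.spark.sql.SparkSession

// Hypothetical scratch object, not part of the PR; demonstrates the
// Spark SQL semantics that PPL `signum` inherits after translation.
object SignumSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("signum-sketch").getOrCreate()
  import spark.implicits._

  // Mirrors the integration test above: pow(-2, id) alternates sign and
  // signum collapses each value to -1.0 or 1.0 (0.0 only for input 0).
  val df = Seq(1, 2, 3, 4, 5).toDF("id")
  df.selectExpr("id", "pow(-2, id) AS i", "signum(pow(-2, id)) AS s").show()
  // id=1 -> i=-2.0, s=-1.0; id=2 -> i=4.0, s=1.0; id=3 -> i=-8.0, s=-1.0; ...

  spark.stop()
}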