ANTLR syntax extensions related to signum function. (opensearch-project#652)

* ANTLR syntax extensions related to signum function.

Signed-off-by: Lukasz Soszynski <[email protected]>

* Test and documentation related to signum function.

Signed-off-by: Lukasz Soszynski <[email protected]>

* Integration test for usage of the signum function in the eval command

Signed-off-by: Lukasz Soszynski <[email protected]>

---------

Signed-off-by: Lukasz Soszynski <[email protected]>
lukasz-soszynski-eliatra authored Oct 1, 2024
1 parent 0aeed27 commit 3048368
Showing 5 changed files with 58 additions and 2 deletions.
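
For context on the function this commit adds: signum returns the sign of its numeric argument as a double, -1.0 for negative values, 0.0 for zero, and 1.0 for positive values. A minimal standalone sketch with Spark's built-in signum column function, which the new PPL function is expected to delegate to (the session, data, and names below are illustrative only and not part of this change):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, signum}

// Illustrative local session, not part of this commit.
val spark = SparkSession.builder().master("local[1]").appName("signum-sketch").getOrCreate()
import spark.implicits._

// signum yields -1.0, 0.0, or 1.0 depending on the sign of the input column.
val df = Seq(-8.0, 0.0, 16.0).toDF("i")
df.select(col("i"), signum(col("i")).as("s")).show()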
@@ -10,7 +10,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, LessThan, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, LogicalPlan, Project, Sort}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLEvalITSuite
@@ -22,13 +22,15 @@ class FlintSparkPPLEvalITSuite
  /** Test table and index name */
  private val testTable = "spark_catalog.default.flint_ppl_test"
  private val testTableHttpLog = "spark_catalog.default.flint_ppl_test_http_log"
  private val duplicatesNullableTestTable = "spark_catalog.default.duplicates_nullable_test"

  override def beforeAll(): Unit = {
    super.beforeAll()

    // Create test table
    createPartitionedStateCountryTable(testTable)
    createTableHttpLog(testTableHttpLog)
    createDuplicationNullableTable(duplicatesNullableTestTable)
  }

  protected override def afterEach(): Unit = {
@@ -632,8 +634,45 @@ class FlintSparkPPLEvalITSuite
EqualTo(Literal(true), and)
}

test("Test eval and signum function") {
val frame = sql(s"""
| source = $duplicatesNullableTestTable | fields id | sort id | eval i = pow(-2, id), s = signum(i) | head 5
| """.stripMargin)
val rows = frame.collect()
val expectedResults: Array[Row] = Array(
Row(1, -2d, -1d),
Row(2, 4d, 1d),
Row(3, -8d, -1d),
Row(4, 16d, 1d),
Row(5, -32d, -1d))
assert(rows.sameElements(expectedResults))

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val tablePlan =
UnresolvedRelation(Seq("spark_catalog", "default", "duplicates_nullable_test"))
val projectIdPlan = Project(Seq(UnresolvedAttribute("id")), tablePlan)
val sortPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("id"), Ascending)), global = true, projectIdPlan)
val evalPlan = Project(
Seq(
UnresolvedStar(None),
Alias(
UnresolvedFunction(
"pow",
Seq(Literal(-2), UnresolvedAttribute("id")),
isDistinct = false),
"i")(),
Alias(
UnresolvedFunction("signum", Seq(UnresolvedAttribute("i")), isDistinct = false),
"s")()),
sortPlan)
val localLimitPlan = LocalLimit(Literal(5), evalPlan)
val globalLimitPlan = GlobalLimit(Literal(5), localLimitPlan)
val expectedPlan = Project(Seq(UnresolvedStar(None)), globalLimitPlan)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

  // TODO: excluded fields not supported yet
  ignore("test single eval expression with excluded fields") {
    val frame = sql(s"""
         | source = $testTable | eval new_field = "New Field" | fields - age
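
As a quick sanity check on the expected rows in the signum integration test above, the same arithmetic can be reproduced with plain Scala math functions (a standalone illustration, outside the Spark test harness):

// For id = 1 to 5, pow(-2, id) alternates in sign, so signum alternates between -1.0 and 1.0.
val expected = (1 to 5).map { id =>
  val i = math.pow(-2, id) // -2.0, 4.0, -8.0, 16.0, -32.0
  val s = math.signum(i)   // -1.0, 1.0, -1.0, 1.0, -1.0
  (id, i, s)
}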
1 change: 1 addition & 0 deletions ppl-spark-integration/README.md
@@ -333,6 +333,7 @@ Limitation: Overriding existing field is unsupported, following queries throw exceptions
- `source = table | eval a = 10 | fields a,b,c`
- `source = table | eval a = a * 2 | stats avg(a)`
- `source = table | eval a = abs(a) | where a > 0`
- `source = table | eval a = signum(a) | where a < 0`
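
(The queries above fail because they override an existing field; assigning the result to a new field instead, for example `source = table | eval s = signum(a) | where s < 0`, follows the supported pattern. This example is illustrative and not part of the README change.)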

**Aggregations**
- `source = table | stats avg(a) `
@@ -259,6 +259,7 @@ POWER: 'POWER';
RAND: 'RAND';
ROUND: 'ROUND';
SIGN: 'SIGN';
SIGNUM: 'SIGNUM';
SQRT: 'SQRT';
TRUNCATE: 'TRUNCATE';

@@ -568,6 +568,7 @@ mathematicalFunctionName
| RAND
| ROUND
| SIGN
| SIGNUM
| SQRT
| TRUNCATE
| trigonometricFunctionName
@@ -191,4 +191,18 @@ class PPLLogicalPlanMathFunctionsTranslatorTestSuite
    val expectedPlan = Project(projectList, evalProject)
    comparePlans(expectedPlan, logPlan, false)
  }

test("test signum") {
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(plan(pplParser, "source=t a = signum(b)"), context)

val table = UnresolvedRelation(Seq("t"))
val filterExpr = EqualTo(
UnresolvedAttribute("a"),
UnresolvedFunction("signum", seq(UnresolvedAttribute("b")), isDistinct = false))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, filterPlan)
comparePlans(expectedPlan, logPlan, false)
}
}
