diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAppendColITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAppendColITSuite.scala index 2b2a001c7..7341be593 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAppendColITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLAppendColITSuite.scala @@ -39,6 +39,8 @@ class FlintSparkPPLAppendColITSuite private val AGE_ALIAS = Alias(UnresolvedAttribute("age"), "age")() + private val COUNTRY_ALIAS = Alias(UnresolvedAttribute("country"), "country")() + private val RELATION_TEST_TABLE = UnresolvedRelation( Seq("spark_catalog", "default", "flint_ppl_test")) @@ -598,4 +600,91 @@ class FlintSparkPPLAppendColITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } + // @formatter:off + /** + * 'Project [*] + * +- 'DataFrameDropColumns ['APPENDCOL_T1._row_number_, 'APPENDCOL_T2._row_number_] + * +- 'Join FullOuter, ('APPENDCOL_T1._row_number_ = 'APPENDCOL_T2._row_number_) + * :- 'SubqueryAlias APPENDCOL_T1 + * : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#977, *] + * : +- 'Project ['country, 'avg_age1] + * : +- 'Aggregate ['country AS country#975], ['AVG('age) AS avg_age1#974, 'country AS country#975] + * : +- 'UnresolvedRelation [testTable], [], false + * +- 'SubqueryAlias APPENDCOL_T2 + * +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#981, *] + * +- 'Project ['avg_age2] + * +- 'Aggregate ['country AS country#979], ['AVG('age) AS avg_age2#978, 'country AS country#979] + * +- 'UnresolvedRelation [testTable], [], false + */ + // @formatter:on + test("test AppendCol with multiple stats commands") { + val frame = sql(s""" + | source = $testTable | stats avg(age) as avg_age1 by country | fields country, avg_age1 | appendcol [stats avg(age) as avg_age2 by country | fields avg_age2]; + | """.stripMargin) + + assert(frame.columns.sameElements(Array("country", "avg_age1", "avg_age2"))) + // Retrieve the results + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row("USA", 50.0, 50.0), + Row("Canada", 22.5, 22.5)) + // Compare the results + results.foreach(row => println(row)) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + + + /* + * :- 'SubqueryAlias APPENDCOL_T1 + * : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#977, *] + * : +- 'Project ['country, 'avg_age1] + * : +- 'Aggregate ['country AS country#975], ['AVG('age) AS avg_age1#974, 'country AS country#975] + * : +- 'UnresolvedRelation [testTable], [], false + */ + val t1 = SubqueryAlias( + "APPENDCOL_T1", + Project( + Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), + Project( + Seq(UnresolvedAttribute("country"), UnresolvedAttribute("avg_age1")), + Aggregate(COUNTRY_ALIAS :: Nil, + Seq(Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("age")), isDistinct = false), + "avg_age1")(), COUNTRY_ALIAS), + RELATION_TEST_TABLE)))) + + + /* + * +- 'SubqueryAlias APPENDCOL_T2 + * +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#981, *] + * +- 'Project ['avg_age2] + * +- 'Aggregate ['country AS country#979], ['AVG('age) AS avg_age2#978, 'country AS country#979] + * +- 'UnresolvedRelation [testTable], [], false + */ + val t2 = SubqueryAlias( + "APPENDCOL_T2", + Project( + Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)), + Project( + Seq(UnresolvedAttribute("avg_age2")), + Aggregate(COUNTRY_ALIAS :: Nil, + Seq(Alias( + UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("age")), isDistinct = false), + "avg_age2")(), COUNTRY_ALIAS), + RELATION_TEST_TABLE)))) + + + val expectedPlan = Project( + Seq(UnresolvedStar(None)), + DataFrameDropColumns( + T12_COLUMNS_SEQ, + Join(t1, t2, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE))) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + }