Skip to content

Commit

Permalink
Multiple avg commands
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Kwok <[email protected]>
  • Loading branch information
andy-k-improving committed Jan 3, 2025
1 parent 0c8e3d1 commit 820df4e
Showing 1 changed file with 89 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class FlintSparkPPLAppendColITSuite

// Alias expression that names the unresolved `age` attribute "age".
private val AGE_ALIAS: Alias =
  Alias(child = UnresolvedAttribute("age"), name = "age")()

// Alias expression that names the unresolved `country` attribute "country"; the aggregate
// nodes of the expected plans in this suite use it as grouping key and output column.
private val COUNTRY_ALIAS: Alias =
  Alias(child = UnresolvedAttribute("country"), name = "country")()

// Unresolved relation for the three-part identifier spark_catalog.default.flint_ppl_test.
private val RELATION_TEST_TABLE: UnresolvedRelation =
  UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

Expand Down Expand Up @@ -598,4 +600,91 @@ class FlintSparkPPLAppendColITSuite
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

// @formatter:off
/**
* 'Project [*]
* +- 'DataFrameDropColumns ['APPENDCOL_T1._row_number_, 'APPENDCOL_T2._row_number_]
* +- 'Join FullOuter, ('APPENDCOL_T1._row_number_ = 'APPENDCOL_T2._row_number_)
* :- 'SubqueryAlias APPENDCOL_T1
* : +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#977, *]
* : +- 'Project ['country, 'avg_age1]
* : +- 'Aggregate ['country AS country#975], ['AVG('age) AS avg_age1#974, 'country AS country#975]
* : +- 'UnresolvedRelation [testTable], [], false
* +- 'SubqueryAlias APPENDCOL_T2
* +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#981, *]
* +- 'Project ['avg_age2]
* +- 'Aggregate ['country AS country#979], ['AVG('age) AS avg_age2#978, 'country AS country#979]
* +- 'UnresolvedRelation [testTable], [], false
*/
// @formatter:on
test("test AppendCol with multiple stats commands") {
  // The main search computes avg(age) per country and keeps (country, avg_age1); the appendcol
  // sub-search repeats the same aggregation under the alias avg_age2 and keeps only that column.
  val frame = sql(s"""
                     | source = $testTable | stats avg(age) as avg_age1 by country | fields country, avg_age1 | appendcol [stats avg(age) as avg_age2 by country | fields avg_age2];
                     | """.stripMargin)

  // The output must expose the main-search columns followed by the appended column.
  assert(frame.columns.sameElements(Array("country", "avg_age1", "avg_age2")))

  // Retrieve the results
  val results: Array[Row] = frame.collect()
  val expectedResults: Array[Row] =
    Array(Row("USA", 50.0, 50.0), Row("Canada", 22.5, 22.5))

  // Compare the results. Both sides are sorted by the first column (country) so the check
  // does not depend on the order in which rows are collected.
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
  assert(results.sorted.sameElements(expectedResults.sorted))

  // Retrieve the logical plan
  val logicalPlan: LogicalPlan = frame.queryExecution.logical

  /*
   * :- 'SubqueryAlias APPENDCOL_T1
   * :  +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#977, *]
   * :     +- 'Project ['country, 'avg_age1]
   * :        +- 'Aggregate ['country AS country#975], ['AVG('age) AS avg_age1#974, 'country AS country#975]
   * :           +- 'UnresolvedRelation [testTable], [], false
   */
  val t1 = SubqueryAlias(
    "APPENDCOL_T1",
    Project(
      Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
      Project(
        Seq(UnresolvedAttribute("country"), UnresolvedAttribute("avg_age1")),
        Aggregate(
          COUNTRY_ALIAS :: Nil,
          Seq(
            Alias(
              UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("age")), isDistinct = false),
              "avg_age1")(),
            COUNTRY_ALIAS),
          RELATION_TEST_TABLE))))

  /*
   * +- 'SubqueryAlias APPENDCOL_T2
   *    +- 'Project [row_number() windowspecdefinition(1 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_#981, *]
   *       +- 'Project ['avg_age2]
   *          +- 'Aggregate ['country AS country#979], ['AVG('age) AS avg_age2#978, 'country AS country#979]
   *             +- 'UnresolvedRelation [testTable], [], false
   */
  val t2 = SubqueryAlias(
    "APPENDCOL_T2",
    Project(
      Seq(ROW_NUMBER_AGGREGATION, UnresolvedStar(None)),
      Project(
        Seq(UnresolvedAttribute("avg_age2")),
        Aggregate(
          COUNTRY_ALIAS :: Nil,
          Seq(
            Alias(
              UnresolvedFunction(Seq("AVG"), Seq(UnresolvedAttribute("age")), isDistinct = false),
              "avg_age2")(),
            COUNTRY_ALIAS),
          RELATION_TEST_TABLE))))

  // Top of the expected plan: full outer join of the two aliased sub-plans, with the helper
  // columns listed in T12_COLUMNS_SEQ dropped before the final star projection.
  val expectedPlan = Project(
    Seq(UnresolvedStar(None)),
    DataFrameDropColumns(
      T12_COLUMNS_SEQ,
      Join(t1, t2, FullOuter, Some(T12_JOIN_CONDITION), JoinHint.NONE)))

  comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

}

0 comments on commit 820df4e

Please sign in to comment.