
[SPARK-50762][SQL] Add Analyzer rule for resolving SQL scalar UDFs #49414

Conversation

@allisonwang-db allisonwang-db commented Jan 8, 2025

What changes were proposed in this pull request?

This PR adds a new Analyzer rule, ResolveSQLFunctions, that resolves scalar SQL UDFs by replacing each SQLFunctionExpression with the actual function body. It currently supports the following operators: Project, Filter, Join, and Aggregate.

For example:

CREATE FUNCTION area(width DOUBLE, height DOUBLE) RETURNS DOUBLE
RETURN width * height;

and this query

SELECT area(a, b) FROM t;

will be resolved as

Project [area(width, height) AS area]
  +- Project [a, b, CAST(a AS DOUBLE) AS width, CAST(b AS DOUBLE) AS height]
    +- Relation [a, b]
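The shape of this rewrite can be sketched with a tiny, self-contained expression tree (plain Scala; `Attr`, `Cast`, `Mul`, `Alias`, and `inlineArea` are illustrative stand-ins, not the actual Catalyst classes):

```scala
// Minimal stand-ins for Catalyst expressions (illustrative only).
sealed trait Expr
case class Attr(name: String) extends Expr
case class Cast(child: Expr, to: String) extends Expr
case class Mul(left: Expr, right: Expr) extends Expr
case class Alias(child: Expr, name: String) extends Expr

// Inline `area(a, b)`: cast the actual arguments to the declared parameter
// types in a bottom projection, then evaluate the function body over those
// aliased columns in the projection above.
def inlineArea(a: Expr, b: Expr): (Seq[Alias], Alias) = {
  val bottom = Seq(
    Alias(Cast(a, "DOUBLE"), "width"),
    Alias(Cast(b, "DOUBLE"), "height"))
  val top = Alias(Mul(Attr("width"), Attr("height")), "area")
  (bottom, top)
}
```

This mirrors the resolved plan above: the bottom Project carries the casts of `a` and `b` to the declared parameter types, and the top Project computes the body over `width` and `height`.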

Why are the changes needed?

To support SQL UDFs.

Does this PR introduce any user-facing change?

No

How was this patch tested?

New SQL query tests. More tests will be added once table function resolution is supported.

Was this patch authored or co-authored using generative AI tooling?

No

if !f.resolved || AggregateExpression.containsAggregate(cond) ||
ResolveGroupingAnalytics.hasGroupingFunction(cond) ||
cond.containsPattern(TEMP_RESOLVED_COLUMN) =>
// If the filter's condition contains aggregate expressions or grouping functions or temp
Contributor

Suggested change
// If the filter's condition contains aggregate expressions or grouping functions or temp
// If the filter's condition contains aggregate expressions or grouping expressions or temp

val topProject = if (topProjectList.nonEmpty) Project(topProjectList, newAgg) else newAgg
topProject
Contributor

Suggested change
val topProject = if (topProjectList.nonEmpty) Project(topProjectList, newAgg) else newAgg
topProject
if (topProjectList.nonEmpty) Project(topProjectList, newAgg) else newAgg

* Example (aggregate):
* Before: foo(c1) + foo(max(c2)) + max(foo(c2))
* After: foo(c1) + foo(max_c2) + max_foo_c2
* Extracted expressions: [c1, max(c2) AS max_c2, max(foo(c2)) AS max_foo_c2]
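The extraction described in this doc comment can be sketched with a small self-contained model (plain Scala; `Max` stands in for an aggregate function, `Foo` for a SQL UDF call, and `extractAndRewrite` with `_aggN` alias names is an illustrative simplification of the real helper):

```scala
import scala.collection.mutable.ArrayBuffer

// Minimal stand-ins for Catalyst expressions (illustrative only).
sealed trait Expr
case class Attr(name: String) extends Expr
case class Max(child: Expr) extends Expr             // an aggregate function
case class Foo(child: Expr) extends Expr             // a SQL UDF call
case class Add(left: Expr, right: Expr) extends Expr
case class Alias(child: Expr, name: String) extends Expr

// Replace each outermost aggregate function with a reference to an extracted,
// aliased expression collected in `aggExprs`; the rewritten expression then
// lives in a Project above the Aggregate.
def extractAndRewrite(e: Expr, aggExprs: ArrayBuffer[Alias]): Expr = e match {
  case m: Max =>
    val name = s"_agg${aggExprs.size}"
    aggExprs += Alias(m, name)
    Attr(name)
  case Foo(c)      => Foo(extractAndRewrite(c, aggExprs))
  case Add(l, r)   => Add(extractAndRewrite(l, aggExprs), extractAndRewrite(r, aggExprs))
  case Alias(c, n) => Alias(extractAndRewrite(c, aggExprs), n)
  case other       => other
}
```

On `foo(c1) + foo(max(c2)) + max(foo(c2))` this extracts `max(c2)` and `max(foo(c2))` whole (so the latter's UDF call stays inside the Aggregate), while `foo(c1)` keeps only the plain column reference.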
Contributor

@cloud-fan cloud-fan Jan 8, 2025

This reminds me of Aggregate normalization we did in RewriteWithExpression, which moves the result projection from Aggregate and puts it in a new Project node above Aggregate.

It's no harm to do this normalization but for safety we only do it when we have to, like the With expression and SQL UDF.

h.copy(child = a.copy(child = rewrite(a.child)))

case a: Aggregate if a.resolved && hasSQLFunctionExpression(a.expressions) =>
val child = rewrite(a.child)
Contributor

shall we rewrite SQL function top-down? Then the newly created Project under Aggregate can be rewritten in one pass.

Contributor

Or we create a util function that only rewrites a single node, then we call it at the end of Aggregate rewriting to rewrite the newly created Project.

Contributor Author

Yeah, that's something that can be explored. I plan to add more tests in the upcoming PRs to ensure correctness first, and after that we can make further improvements.

@allisonwang-db allisonwang-db force-pushed the spark-50762-resolve-scalar-udf branch from b865df5 to 1b26abe on January 8, 2025 at 12:37
// Outer references also need to be wrapped because the function input may already
// contain outer references.
val outer = expr.transform {
case a: Attribute => OuterReference(a)
Contributor

Why do we use OuterReference if we always rewrite SQL UDFs to scalar expressions?

Contributor Author

The first step to resolve a SQL UDF is to verify that the function body (expression or subquery) can be resolved correctly using the captured SQL configs. We wrap the function inputs with outer references so that we can run a simple analyzer on top:

Project [CAST(width * height AS DOUBLE) AS area]
  +- Project [CAST(outer(a) AS DOUBLE) AS width, CAST(outer(b) AS DOUBLE) AS height]
    +- OneRowRelation

Once analyzed, the next step is to inline the SQL UDF body into the original query plan tree (rewriteSQLFunctions).
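That wrapping step amounts to a bottom-up transform over the input expression. A minimal sketch, assuming illustrative stand-in classes (`Attr`, `OuterRef`, `Mul`, `wrapOuter`) rather than Catalyst's:

```scala
// Minimal stand-ins for Catalyst expressions (illustrative only).
sealed trait Expr
case class Attr(name: String) extends Expr
case class OuterRef(attr: Attr) extends Expr
case class Mul(left: Expr, right: Expr) extends Expr

// Wrap every attribute in the function input with an outer reference, so the
// function body can be analyzed on top of OneRowRelation with the wrapped
// inputs treated as correlated outer columns.
def wrapOuter(e: Expr): Expr = e match {
  case a: Attr   => OuterRef(a)
  case Mul(l, r) => Mul(wrapOuter(l), wrapOuter(r))
  case other     => other
}
```

For example, wrapping the input `a * b` yields `outer(a) * outer(b)`, matching the `outer(...)` markers in the plan above.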

* A wrapper node for a SQL scalar function expression.
*/
case class SQLScalarFunction(function: SQLFunction, inputs: Seq[Expression], child: Expression)
extends UnaryExpression with UnaryLike[Expression] with Unevaluable {
Contributor

Suggested change
extends UnaryExpression with UnaryLike[Expression] with Unevaluable {
extends UnaryExpression with Unevaluable {

*
* Analyzed plan:
*
* Project [foo(x) AS foo]
Contributor

How do we eliminate the Aggregate and Filter?

// aggregate functions. They need to be extracted into a project list above the
// current aggregate.
val aggExprs = ArrayBuffer.empty[NamedExpression]
val topProjectList = aggregateExpressions.map(extractAndRewrite(_, aggExprs))
Contributor

@cloud-fan cloud-fan Jan 9, 2025

We can follow RewriteWithExpression to get the top project in a simpler way:

val PhysicalAggregation(groupingExprs, aggExprs, resultExprs, _, _) = a
val newGroupingExprs = groupingExprs.map(rewriteSQLFunctions(_, bottomProjectList))
val newAggExprs = aggExprs.map(rewriteSQLFunctions(_, bottomProjectList))
...

Another issue is that grouping expressions may appear in aggregateExpressions as well, and we want to avoid duplicating the SQL function expression. This can be solved by PullOutGroupingExpressions. We can create a util function in PullOutGroupingExpressions that rewrites a single Aggregate, then leverage it here. To put everything together:

val rewritten = PullOutGroupingExpressions.rewriteAgg(a)
val PhysicalAggregation(groupingExprs, aggExprs, resultExprs, _, _) = rewritten
val newAggExprs = aggExprs.map(rewriteSQLFunctions(_, bottomProjectList))
// no need to rewrite grouping expr as it won't contain SQL UDF now.
Project(resultExprs, rewritten.copy(
  aggExprs = newAggExprs,
  child = Project(bottomProjectList, rewritten.child))
)

Contributor

Note: this is a more aggressive rewrite, which rewrites all aggregate/grouping expressions with the same idea used for rewriting SQL UDFs. The code is simpler but the plan change is larger. I'm also OK with keeping the current implementation as it is.

Contributor Author

Sounds good. We can explore this once we have more test coverage.

/**
* Test suite for SQL user-defined functions (UDFs).
*/
class SQLFunctionSuite extends QueryTest with SharedSparkSession {
Contributor

Is this a duplication of the golden file tests? Those are also end-to-end.

Contributor Author

It's intended to test plan structures (for more complicated queries) and to cover other DDL commands in the future (such as DESCRIBE).

@allisonwang-db allisonwang-db force-pushed the spark-50762-resolve-scalar-udf branch from 73b650f to 46fb145 on January 13, 2025 at 21:51
@cloud-fan
Contributor

thanks, merging to master!

@cloud-fan cloud-fan closed this in bba6839 Jan 14, 2025