[SPARK-48344][SQL] Enhance SQL Script Execution: Replace NOOP with COLLECT for Result DataFrames #49372

Open · wants to merge 7 commits into base: master
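The gist of the change, condensed from the SparkSession.scala diff below (a simplified sketch; withErrorHandling and surrounding code are omitted):

// Before: intermediate result DataFrames were executed via a no-op write and only
// the last one was collected.
while (sse.hasNext) {
  val df = sse.next()
  if (sse.hasNext) df.write.format("noop").mode("overwrite").save()
  else result = Some(df.collect().toSeq)
}

// After: SqlScriptingExecution no longer extends Iterator[DataFrame]; the caller pulls
// results with getNextResult and collects each one immediately, preserving execution order.
var df: Option[DataFrame] = sse.getNextResult
while (df.isDefined) {
  result = Some(df.get.collect().toSeq)
  df = sse.getNextResult
}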
@@ -62,6 +62,8 @@ case class SingleStatement(parsedPlan: LogicalPlan)
* @param label Label set to CompoundBody by user or UUID otherwise.
*              It can be None when the CompoundBody is not part of a BeginEndCompoundBlock,
*              for example when the CompoundBody is inside a loop or conditional block.
+ * @param isScope Flag indicating if the CompoundBody is a labeled scope.
+ *                Scopes are used for grouping local variables and exception handlers.
*/
case class CompoundBody(
collection: Seq[CompoundPlanStatement],
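For intuition, a hypothetical script in the style of the scripting test suites (not part of this diff; exact syntax is defined by the SQL scripting grammar), where the labeled inner block is a CompoundBody acting as a scope:

// Hypothetical example: the labeled BEGIN ... END block forms a scope that groups
// the local variable `x` declared inside it.
val sqlText =
  """BEGIN
    |  lbl: BEGIN
    |    DECLARE x INT DEFAULT 1;
    |    SELECT x;
    |  END lbl;
    |END""".stripMargin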
17 changes: 9 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -448,16 +448,17 @@ class SparkSession private(
    val sse = new SqlScriptingExecution(script, this, args)
    var result: Option[Seq[Row]] = None

-   while (sse.hasNext) {
+   // We must execute the returned df before calling sse.getNextResult again because getNextResult
+   // advances the script execution and executes all statements until the next result. We must
+   // collect results immediately to maintain execution order.
+   // This ensures we respect the contract of the SqlScriptingExecution API.
+   var df: Option[DataFrame] = sse.getNextResult
+   while (df.isDefined) {
      sse.withErrorHandling {
-       val df = sse.next()
-       if (sse.hasNext) {
-         df.write.format("noop").mode("overwrite").save()
-       } else {
-         // Collect results from the last DataFrame.
-         result = Some(df.collect().toSeq)
-       }
+       // Collect results from the current DataFrame.
+       result = Some(df.get.collect().toSeq)
      }
+     df = sse.getNextResult
    }

    if (result.isEmpty) {
@@ -17,14 +17,16 @@

package org.apache.spark.sql.scripting

- import org.apache.spark.SparkException
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody}

/**
* SQL scripting executor - executes script and returns result statements.
* This supports returning multiple result statements from a single script.
+ * The caller of the SqlScriptingExecution API must adhere to the contract of executing
+ * the returned statement before continuing iteration. Executing the statement needs to be
+ * done inside the withErrorHandling block.
*
* @param sqlScript CompoundBody which needs to be executed.
* @param session Spark session that SQL script is executed within.
@@ -33,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody}
class SqlScriptingExecution(
sqlScript: CompoundBody,
session: SparkSession,
-   args: Map[String, Expression]) extends Iterator[DataFrame] {
+   args: Map[String, Expression]) {

private val interpreter = SqlScriptingInterpreter(session)

@@ -42,40 +44,40 @@ class SqlScriptingExecution(
val ctx = new SqlScriptingExecutionContext()
val executionPlan = interpreter.buildExecutionPlan(sqlScript, args, ctx)
// Add frame which represents SQL Script to the context.
- ctx.frames.addOne(new SqlScriptingExecutionFrame(executionPlan.getTreeIterator))
+ ctx.frames.append(new SqlScriptingExecutionFrame(executionPlan.getTreeIterator))
// Enter the scope of the top level compound.
// We don't need to exit this scope explicitly as it will be done automatically
// when the frame is removed during iteration.
executionPlan.enterScope()
ctx
}

- private var current: Option[DataFrame] = getNextResult
-
- override def hasNext: Boolean = current.isDefined
-
- override def next(): DataFrame = {
-   current match {
-     case None => throw SparkException.internalError("No more elements to iterate through.")
-     case Some(result) =>
-       current = getNextResult
-       result
-   }
- }

/** Helper method to get the next statement from the first available frame. */
private def getNextStatement: Option[CompoundStatementExec] = {
// Remove frames that are already executed.
while (context.frames.nonEmpty && !context.frames.last.hasNext) {
context.frames.remove(context.frames.size - 1)
}
// If there are still frames available, get the next statement.
if (context.frames.nonEmpty) {
return Some(context.frames.last.next())
}
None
}
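For readers less familiar with the frame bookkeeping above, a minimal self-contained sketch of the same idea (FrameStack and its members are illustrative names, not Spark API):

import scala.collection.mutable.ListBuffer

// Keep a stack of iterators ("frames"), drop exhausted frames from the top,
// then pull the next element from the innermost remaining frame.
final class FrameStack[A](frames: ListBuffer[Iterator[A]]) {
  def nextStatement(): Option[A] = {
    while (frames.nonEmpty && !frames.last.hasNext) {
      frames.remove(frames.size - 1)
    }
    if (frames.nonEmpty) Some(frames.last.next()) else None
  }
}

// Usage: the innermost frame is drained first, then execution falls back to the outer frame.
val frames = ListBuffer[Iterator[String]](Iterator("outer-1", "outer-2"), Iterator("inner-1"))
val stack = new FrameStack(frames)
// nextStatement() yields Some("inner-1"), then Some("outer-1"), Some("outer-2"), then None.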

- /** Helper method to iterate through statements until next result statement is encountered. */
- private def getNextResult: Option[DataFrame] = {
+ /**
+  * Advances through the script and executes statements until a result statement or
+  * the end of the script is encountered.
+  *
+  * To know whether a result statement is available, the method has to advance through the script
+  * and execute statements until a result statement or the end of the script is encountered. For
+  * that reason the returned result must be executed before subsequent calls. Multiple calls
+  * without executing the intermediate results will lead to incorrect behavior.
+  *
+  * @return The result DataFrame if one is available, otherwise None.
+  */
+ def getNextResult: Option[DataFrame] = {
var currentStatement = getNextStatement
// While we don't have a result statement, execute the statements.
while (currentStatement.isDefined) {
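Taken together, the class doc and the getNextResult contract imply a consumption loop like the following (a minimal sketch mirroring the SparkSession.scala change in this PR; script, spark, and args are placeholders):

// Each result returned by getNextResult must be executed (here: collected) inside
// withErrorHandling before asking for the next result, so execution order is preserved.
val sse = new SqlScriptingExecution(script, spark, args)
var result: Option[Seq[Row]] = None

var df: Option[DataFrame] = sse.getNextResult
while (df.isDefined) {
  sse.withErrorHandling {
    result = Some(df.get.collect().toSeq)
  }
  df = sse.getNextResult
}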
@@ -63,7 +63,7 @@ class SqlScriptingExecutionFrame(
}

def enterScope(label: String): Unit = {
- scopes.addOne(new SqlScriptingExecutionScope(label))
+ scopes.append(new SqlScriptingExecutionScope(label))
}

def exitScope(label: String): Unit = {
@@ -76,6 +76,7 @@ class SqlScriptingExecutionFrame(
scopes.remove(scopes.length - 1)
}

+ // Remove the scope with the given label.
if (scopes.nonEmpty) {
scopes.remove(scopes.length - 1)
}
@@ -181,7 +181,8 @@ class NoOpStatementExec extends LeafStatementExec {
* @param label
* Label set by user to CompoundBody or None otherwise.
* @param isScope
- Flag that indicates whether Compound Body is scope or not.
+ Flag indicating if the CompoundBody is a labeled scope.
+ Scopes are used for grouping local variables and exception handlers.
* @param context
* SqlScriptingExecutionContext keeps the execution state of current script.
*/
@@ -17,6 +17,8 @@

package org.apache.spark.sql.scripting

+ import scala.collection.mutable.ListBuffer

import org.apache.spark.SparkConf
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.Expression
@@ -43,7 +45,15 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
args: Map[String, Expression] = Map.empty): Seq[Array[Row]] = {
val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody]
val sse = new SqlScriptingExecution(compoundBody, spark, args)
- sse.map { df => df.collect() }.toList
+ val result: ListBuffer[Array[Row]] = ListBuffer.empty
+
+ var df = sse.getNextResult
+ while (df.isDefined) {
+   // Collect results from the current DataFrame.
+   result.append(df.get.collect())
+   df = sse.getNextResult
+ }
+ result.toSeq
}

private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = {
@@ -49,7 +49,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
// Initialize context so scopes can be entered correctly.
val context = new SqlScriptingExecutionContext()
val executionPlan = interpreter.buildExecutionPlan(compoundBody, args, context)
- context.frames.addOne(new SqlScriptingExecutionFrame(executionPlan.getTreeIterator))
+ context.frames.append(new SqlScriptingExecutionFrame(executionPlan.getTreeIterator))
executionPlan.enterScope()

executionPlan.getTreeIterator.flatMap {