From b3fdd8c151558edaa7f4eff183262ffe67e0f434 Mon Sep 17 00:00:00 2001
From: Milan Dankovic
Date: Tue, 24 Dec 2024 18:34:35 +0100
Subject: [PATCH 1/5] Minor improvements

---
 .../sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala  | 2 ++
 .../spark/sql/scripting/SqlScriptingExecutionContext.scala     | 2 +-
 .../apache/spark/sql/scripting/SqlScriptingExecutionNode.scala | 3 ++-
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
index 207c586996fd8..ad00a5216b4c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
@@ -62,6 +62,8 @@ case class SingleStatement(parsedPlan: LogicalPlan)
  * @param label Label set to CompoundBody by user or UUID otherwise.
  *              It can be None in case when CompoundBody is not part of BeginEndCompoundBlock
  *              for example when CompoundBody is inside loop or conditional block.
+ * @param isScope Flag indicating if the CompoundBody is a labeled scope.
+ *                Scopes are used for grouping local variables and exception handlers.
  */
 case class CompoundBody(
     collection: Seq[CompoundPlanStatement],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
index 5a2ef62e3bb7d..3e7a5b05c9545 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
@@ -63,7 +63,7 @@ class SqlScriptingExecutionFrame(
   }
 
   def enterScope(label: String): Unit = {
-    scopes.addOne(new SqlScriptingExecutionScope(label))
+    scopes.append(new SqlScriptingExecutionScope(label))
   }
 
   def exitScope(label: String): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 2d50d37e2cb83..74b7b7bbe43a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -181,7 +181,8 @@ class NoOpStatementExec extends LeafStatementExec {
  * @param label
  *   Label set by user to CompoundBody or None otherwise.
  * @param isScope
- *   Flag that indicates whether Compound Body is scope or not.
+ *   Flag indicating if the CompoundBody is a labeled scope.
+ *   Scopes are used for grouping local variables and exception handlers.
  * @param context
  *   SqlScriptingExecutionContext keeps the execution state of current script.
 */
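
The isScope documentation and the enterScope change above both concern the runtime scope stack: each labeled compound gets its own SqlScriptingExecutionScope that groups local variables and exception handlers. The following is a minimal standalone sketch of that bookkeeping, with plain String labels standing in for SqlScriptingExecutionScope and an exit rule that pops back to the requested label; it illustrates the idea only and is not the actual Spark implementation.

import scala.collection.mutable.ArrayBuffer

class ScopeStackSketch {
  private val scopes = ArrayBuffer.empty[String]

  // Entering a labeled compound pushes its scope onto the stack.
  def enterScope(label: String): Unit = scopes.append(label)

  // Exiting pops nested scopes until the requested label is on top, then pops it too.
  def exitScope(label: String): Unit = {
    while (scopes.nonEmpty && scopes.last != label) {
      scopes.remove(scopes.length - 1)
    }
    if (scopes.nonEmpty) {
      scopes.remove(scopes.length - 1)
    }
  }
}
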
From 4c3c142f0b61e5e98ec7c318822f55a12af03d46 Mon Sep 17 00:00:00 2001
From: Milan Dankovic
Date: Fri, 3 Jan 2025 18:06:49 +0100
Subject: [PATCH 2/5] Replace noop with collect

---
 .../scala/org/apache/spark/sql/SparkSession.scala      | 10 +++-------
 .../spark/sql/scripting/SqlScriptingExecution.scala    | 13 ++++++++-----
 .../scripting/SqlScriptingExecutionContext.scala       |  1 +
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 878fdc8e267a5..ee3b19c487cab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -449,14 +449,10 @@ class SparkSession private(
     var result: Option[Seq[Row]] = None
 
     while (sse.hasNext) {
+      val df = sse.next()
       sse.withErrorHandling {
-        val df = sse.next()
-        if (sse.hasNext) {
-          df.write.format("noop").mode("overwrite").save()
-        } else {
-          // Collect results from the last DataFrame.
-          result = Some(df.collect().toSeq)
-        }
+        // Collect results from the current DataFrame.
+        result = Some(df.collect().toSeq)
       }
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
index 71b44cbbd0704..3c1496269938c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
@@ -50,24 +50,27 @@ class SqlScriptingExecution(
     ctx
   }
 
-  private var current: Option[DataFrame] = getNextResult
+  private var current: Option[DataFrame] = None
 
-  override def hasNext: Boolean = current.isDefined
+  override def hasNext: Boolean = {
+    current = getNextResult
+    current.isDefined
+  }
 
   override def next(): DataFrame = {
     current match {
       case None => throw SparkException.internalError("No more elements to iterate through.")
-      case Some(result) =>
-        current = getNextResult
-        result
+      case Some(result) => result
     }
   }
 
   /** Helper method to iterate get next statements from the first available frame. */
   private def getNextStatement: Option[CompoundStatementExec] = {
+    // Remove frames that are already executed.
     while (context.frames.nonEmpty && !context.frames.last.hasNext) {
      context.frames.remove(context.frames.size - 1)
     }
+    // If there are still frames available, get the next statement.
     if (context.frames.nonEmpty) {
       return Some(context.frames.last.next())
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
index 3e7a5b05c9545..94462ab828f75 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionContext.scala
@@ -76,6 +76,7 @@ class SqlScriptingExecutionFrame(
       scopes.remove(scopes.length - 1)
     }
 
+    // Remove the scope with the given label.
    if (scopes.nonEmpty) {
      scopes.remove(scopes.length - 1)
    }
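
The two comments added to getNextStatement above describe a small but important piece of control flow: frames whose iterators are exhausted are popped from the top of the frame stack before the next statement is taken from whichever frame remains. A standalone sketch of that loop is given below; Frame is an illustrative stand-in for SqlScriptingExecutionFrame, which wraps the execution-plan iterator, and is not the actual class.

import scala.collection.mutable.ArrayBuffer

object FrameStackSketch {
  // Illustrative stand-in for SqlScriptingExecutionFrame: a frame is just an iterator of statements.
  final case class Frame(statements: Iterator[String])

  def nextStatement(frames: ArrayBuffer[Frame]): Option[String] = {
    // Remove frames that are already fully executed.
    while (frames.nonEmpty && !frames.last.statements.hasNext) {
      frames.remove(frames.size - 1)
    }
    // If a frame is still available, take the next statement from it.
    if (frames.nonEmpty) Some(frames.last.statements.next()) else None
  }
}
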
From f6741a5806f1aed2e579c4282703fc34756f3643 Mon Sep 17 00:00:00 2001
From: Milan Dankovic
Date: Wed, 8 Jan 2025 12:49:01 +0100
Subject: [PATCH 3/5] Replace addOne with append

---
 .../org/apache/spark/sql/scripting/SqlScriptingExecution.scala  | 2 +-
 .../spark/sql/scripting/SqlScriptingInterpreterSuite.scala      | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
index 3c1496269938c..e214a645b3c29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
@@ -42,7 +42,7 @@ class SqlScriptingExecution(
     val ctx = new SqlScriptingExecutionContext()
     val executionPlan = interpreter.buildExecutionPlan(sqlScript, args, ctx)
     // Add frame which represents SQL Script to the context.
-    ctx.frames.addOne(new SqlScriptingExecutionFrame(executionPlan.getTreeIterator))
+    ctx.frames.append(new SqlScriptingExecutionFrame(executionPlan.getTreeIterator))
     // Enter the scope of the top level compound.
     // We don't need to exit this scope explicitly as it will be done automatically
     // when the frame is removed during iteration.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index 20997504b15eb..c7439a8934d73 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -49,7 +49,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
     // Initialize context so scopes can be entered correctly.
     val context = new SqlScriptingExecutionContext()
     val executionPlan = interpreter.buildExecutionPlan(compoundBody, args, context)
-    context.frames.addOne(new SqlScriptingExecutionFrame(executionPlan.getTreeIterator))
+    context.frames.append(new SqlScriptingExecutionFrame(executionPlan.getTreeIterator))
     executionPlan.enterScope()
     executionPlan.getTreeIterator.flatMap {
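
The addOne/append swap above is behaviour-preserving: on mutable buffers such as ArrayBuffer the two methods do the same in-place append (addOne is the Scala 2.13 name, and += is another alias for it), so the change only aligns the call sites on the more conventional name. A tiny sketch for confirmation:

import scala.collection.mutable.ArrayBuffer

object AppendVsAddOne extends App {
  val frames = ArrayBuffer.empty[Int]
  frames.addOne(1)  // Scala 2.13 name, also spelled +=
  frames.append(2)  // equivalent in-place append
  assert(frames == ArrayBuffer(1, 2))
}
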
From a98302ee4b88877ac3df64d528479662ca2a67fb Mon Sep 17 00:00:00 2001
From: Milan Dankovic
Date: Wed, 8 Jan 2025 15:02:03 +0100
Subject: [PATCH 4/5] Add comments to explain SSE iteration logic and contract

---
 .../org/apache/spark/sql/SparkSession.scala            |  5 +++++
 .../sql/scripting/SqlScriptingExecution.scala          | 21 +++++++++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index ee3b19c487cab..19e37d686bce4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -448,6 +448,11 @@ class SparkSession private(
     val sse = new SqlScriptingExecution(script, this, args)
     var result: Option[Seq[Row]] = None
 
+    // We must call hasNext, next, and collect in this order because:
+    // 1. sse.hasNext is not idempotent - it advances the script execution
+    // 2. sse.next() returns the next result DataFrame
+    // 3. We must collect results immediately to maintain execution order
+    // This ensures we respect the contract of SqlScriptingExecution API.
     while (sse.hasNext) {
       val df = sse.next()
       sse.withErrorHandling {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
index e214a645b3c29..190daed2b6387 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
@@ -25,6 +25,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody}
 /**
  * SQL scripting executor - executes script and returns result statements.
  * This supports returning multiple result statements from a single script.
+ * The caller of the SqlScriptingExecution API must adhere to the contract of executing
+ * the returned statement before continuing iteration. Executing the statement needs to be done
+ * inside a withErrorHandling block.
  *
 * @param sqlScript CompoundBody which need to be executed.
 * @param session Spark session that SQL script is executed within.
@@ -52,11 +55,29 @@ class SqlScriptingExecution(
 
   private var current: Option[DataFrame] = None
 
+  /**
+   * Advances through the script and executes statements until a result statement or
+   * end of script is encountered.
+   *
+   * To know if there is result statement available, the hasNext needs to advance through
+   * script and execute statements until the result statement or end of script is encountered.
+   * For that reason hasNext must be called only once before each `next()` invocation,
+   * and the returned result must be executed before subsequent calls. Multiple calls without
+   * executing the intermediate results will lead to incorrect behavior.
+   *
+   * @return True if a result statement is available, False otherwise.
+   */
   override def hasNext: Boolean = {
     current = getNextResult
     current.isDefined
   }
 
+  /**
+   * Returns the next result statement from the script.
+   * Multiple consecutive calls without calling `hasNext()` would return the same result statement.
+   *
+   * @return The next result statement.
+   */
   override def next(): DataFrame = {
     current match {
       case None => throw SparkException.internalError("No more elements to iterate through.")
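
The comments and scaladoc added above pin down a contract: hasNext is not a pure query, it drives the script forward until the next result statement, so it must be called exactly once per iteration, and the DataFrame returned by next() must be executed (inside withErrorHandling) before hasNext is called again. A hypothetical caller that follows that contract might look like the sketch below; the object and the handleResult callback are illustrative and not part of the API.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.scripting.SqlScriptingExecution

object ScriptResultConsumerSketch {
  def consume(sse: SqlScriptingExecution)(handleResult: DataFrame => Unit): Unit = {
    while (sse.hasNext) {      // advances execution up to the next result statement
      val df = sse.next()      // returns the pending result without advancing further
      sse.withErrorHandling {
        handleResult(df)       // e.g. df.collect(); must run before hasNext is called again
      }
    }
  }
}
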
From dca7fd9738b9e19f5b74871833139faa2ee05d43 Mon Sep 17 00:00:00 2001
From: Milan Dankovic
Date: Thu, 9 Jan 2025 13:58:06 +0100
Subject: [PATCH 5/5] Remove iterator interface from SqlScriptingExecution

---
 .../org/apache/spark/sql/SparkSession.scala         | 14 +++---
 .../sql/scripting/SqlScriptingExecution.scala       | 48 +++++--------
 .../SqlScriptingExecutionSuite.scala                | 12 ++++-
 3 files changed, 31 insertions(+), 43 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 19e37d686bce4..3b36f6b59cb38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -448,17 +448,17 @@ class SparkSession private(
     val sse = new SqlScriptingExecution(script, this, args)
     var result: Option[Seq[Row]] = None
 
-    // We must call hasNext, next, and collect in this order because:
-    // 1. sse.hasNext is not idempotent - it advances the script execution
-    // 2. sse.next() returns the next result DataFrame
-    // 3. We must collect results immediately to maintain execution order
+    // We must execute the returned df before calling sse.getNextResult again, because
+    // getNextResult advances the script execution and executes all statements until the next
+    // result. We must collect results immediately to maintain execution order.
     // This ensures we respect the contract of SqlScriptingExecution API.
-    while (sse.hasNext) {
-      val df = sse.next()
+    var df: Option[DataFrame] = sse.getNextResult
+    while (df.isDefined) {
       sse.withErrorHandling {
         // Collect results from the current DataFrame.
-        result = Some(df.collect().toSeq)
+        result = Some(df.get.collect().toSeq)
       }
+      df = sse.getNextResult
     }
 
     if (result.isEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
index 190daed2b6387..2b15a6c55fa97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.scripting
 
-import org.apache.spark.SparkException
 import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody}
@@ -36,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, CompoundBody}
 class SqlScriptingExecution(
     sqlScript: CompoundBody,
     session: SparkSession,
-    args: Map[String, Expression]) extends Iterator[DataFrame] {
+    args: Map[String, Expression]) {
 
   private val interpreter = SqlScriptingInterpreter(session)
 
@@ -53,37 +52,6 @@ class SqlScriptingExecution(
     ctx
   }
 
-  private var current: Option[DataFrame] = None
-
-  /**
-   * Advances through the script and executes statements until a result statement or
-   * end of script is encountered.
-   *
-   * To know if there is result statement available, the hasNext needs to advance through
-   * script and execute statements until the result statement or end of script is encountered.
-   * For that reason hasNext must be called only once before each `next()` invocation,
-   * and the returned result must be executed before subsequent calls. Multiple calls without
-   * executing the intermediate results will lead to incorrect behavior.
-   *
-   * @return True if a result statement is available, False otherwise.
-   */
-  override def hasNext: Boolean = {
-    current = getNextResult
-    current.isDefined
-  }
-
-  /**
-   * Returns the next result statement from the script.
-   * Multiple consecutive calls without calling `hasNext()` would return the same result statement.
-   *
-   * @return The next result statement.
-   */
-  override def next(): DataFrame = {
-    current match {
-      case None => throw SparkException.internalError("No more elements to iterate through.")
-      case Some(result) => result
-    }
-  }
 
   /** Helper method to iterate get next statements from the first available frame. */
   private def getNextStatement: Option[CompoundStatementExec] = {
@@ -98,8 +66,18 @@ class SqlScriptingExecution(
     None
   }
 
-  /** Helper method to iterate through statements until next result statement is encountered. */
-  private def getNextResult: Option[DataFrame] = {
+  /**
+   * Advances through the script and executes statements until a result statement or
+   * end of script is encountered.
+   *
+   * To know if a result statement is available, the method has to advance through the script and
+   * execute statements until a result statement or the end of the script is encountered. For that
+   * reason the returned result must be executed before subsequent calls. Multiple calls without
+   * executing the intermediate results will lead to incorrect behavior.
+   *
+   * @return Result DataFrame if it is available, otherwise None.
+   */
+  def getNextResult: Option[DataFrame] = {
     var currentStatement = getNextStatement
     // While we don't have a result statement, execute the statements.
     while (currentStatement.isDefined) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
index bbeae942f9fe7..5b5285ea13275 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.scripting
 
+import scala.collection.mutable.ListBuffer
+
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.Expression
@@ -43,7 +45,15 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession {
       args: Map[String, Expression] = Map.empty): Seq[Array[Row]] = {
     val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody]
     val sse = new SqlScriptingExecution(compoundBody, spark, args)
-    sse.map { df => df.collect() }.toList
+    val result: ListBuffer[Array[Row]] = ListBuffer.empty
+
+    var df = sse.getNextResult
+    while (df.isDefined) {
+      // Collect results from the current DataFrame.
+      result.append(df.get.collect())
+      df = sse.getNextResult
+    }
+    result.toSeq
   }
 
   private def verifySqlScriptResult(sqlText: String, expected: Seq[Seq[Row]]): Unit = {
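
With the Iterator interface gone, callers drive the script by polling getNextResult and executing each returned DataFrame before asking for the next one, exactly as the SparkSession change and the test helper above do. The end-to-end sketch below restates that usage under the same assumptions; it reuses only calls that appear in this series, and the object and method names are illustrative rather than part of the patch.

import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.CompoundBody
import org.apache.spark.sql.scripting.SqlScriptingExecution

object RunScriptSketch {
  // Parses a SQL script, runs it, and keeps only the last result set, mirroring SparkSession.
  def lastResult(spark: SparkSession, sqlText: String): Option[Seq[Row]] = {
    val compoundBody = spark.sessionState.sqlParser.parsePlan(sqlText).asInstanceOf[CompoundBody]
    val sse = new SqlScriptingExecution(compoundBody, spark, Map.empty)
    var result: Option[Seq[Row]] = None
    var df: Option[DataFrame] = sse.getNextResult
    while (df.isDefined) {
      sse.withErrorHandling {
        // Execute the current result before advancing the script any further.
        result = Some(df.get.collect().toSeq)
      }
      df = sse.getNextResult
    }
    result
  }
}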