[SPARK-50768][CORE] Introduce TaskContext.createResourceUninterruptibly to avoid stream leak by task interruption

### What changes were proposed in this pull request?

This PR fixes the potential stream leak issue by introducing `TaskContext.createResourceUninterruptibly`.

When a task uses `TaskContext.createResourceUninterruptibly` to create a resource, the task is marked as uninterruptible. Any interruption request that arrives during the call to `TaskContext.createResourceUninterruptibly` is therefore delayed until the resource creation finishes.
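
For illustration, a minimal sketch of how task code would call the new API (`file`, `parser`, and `conf` are placeholders for whatever the caller already has in scope; the API itself is the one added by this PR):

```scala
// Inside running task code, where TaskContext.get() returns the current task's context.
// The reader is built while the task is marked uninterruptible, and the returned
// resource is registered to be closed automatically when the task completes.
val linesReader = TaskContext.get().createResourceUninterruptibly {
  new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
}
```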

This PR introduces new lock contention between `Task.kill` and `TaskContext.createResourceUninterruptibly`, but I think this is acceptable given that neither is on the hot path.

(I will submit a follow-up to apply `TaskContext.createResourceUninterruptibly` across the codebase if this PR is approved by the community.)

### Why are the changes needed?

#48483 previously tried to fix the potential stream leak caused by task interruption. It mitigates the issue with the following utility:

```scala
def tryInitializeResource[R <: Closeable, T](createResource: => R)(initialize: R => T): T = {
  val resource = createResource
  try {
    initialize(resource)
  } catch {
    case e: Throwable =>
      resource.close()
      throw e
  }
}
```
However, this utility function still leaks resources when `initialize(resource)` itself opens additional resources internally and is then interrupted, because `resource.close()` cannot reach anything `initialize` opened but never handed back (see the `InterruptionSensitiveInputStream` example in the test).
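
As a hedged illustration of that failure mode (the class below is a simplified stand-in for the test's `InterruptionSensitiveInputStream`, and the file path is hypothetical):

```scala
import java.io.{FileInputStream, InputStream}

// A wrapper whose inner stream is opened inside initialize(). If the task thread
// is interrupted after the inner stream is opened but before it is assigned to
// `underlying`, closing the wrapper cannot close the inner stream, so it leaks.
class WrappingStream(path: String) extends InputStream {
  private var underlying: InputStream = _

  def initialize(): InputStream = {
    val in = new FileInputStream(path) // inner resource opened here
    // <-- an interrupt landing here leaves `in` unreachable from close()
    underlying = in
    this
  }

  override def read(): Int = if (underlying != null) underlying.read() else -1
  override def close(): Unit = if (underlying != null) underlying.close()
}
```

Wrapping the creation in `TaskContext.createResourceUninterruptibly` closes this window by delaying the interrupt until `initialize()` has returned.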

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a unit test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49413 from Ngone51/dev-interrupt.

Authored-by: Yi Wu <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
Ngone51 authored and LuciferYang committed Jan 15, 2025
1 parent f223b8d commit 6f3b778
Showing 6 changed files with 235 additions and 7 deletions.
13 changes: 13 additions & 0 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -17,6 +17,7 @@

package org.apache.spark

import java.io.Closeable
import java.util.{Properties, TimerTask}
import java.util.concurrent.{ScheduledThreadPoolExecutor, TimeUnit}

@@ -273,6 +274,18 @@ class BarrierTaskContext private[spark] (
}

override private[spark] def getLocalProperties: Properties = taskContext.getLocalProperties

override private[spark] def interruptible(): Boolean = taskContext.interruptible()

override private[spark] def pendingInterrupt(threadToInterrupt: Option[Thread], reason: String)
: Unit = {
taskContext.pendingInterrupt(threadToInterrupt, reason)
}

override private[spark] def createResourceUninterruptibly[T <: Closeable](resourceBuilder: => T)
: T = {
taskContext.createResourceUninterruptibly(resourceBuilder)
}
}

@Experimental
22 changes: 21 additions & 1 deletion core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -17,7 +17,7 @@

package org.apache.spark

import java.io.Serializable
import java.io.Closeable
import java.util.Properties

import org.apache.spark.annotation.{DeveloperApi, Evolving, Since}
@@ -305,4 +305,24 @@ abstract class TaskContext extends Serializable {

/** Gets local properties set upstream in the driver. */
private[spark] def getLocalProperties: Properties

/** Whether the current task is allowed to be interrupted. */
private[spark] def interruptible(): Boolean

/**
* Defers the interruption request until the task becomes interruptible again,
* i.e., after the uninterruptible resource creation finishes.
*/
private[spark] def pendingInterrupt(threadToInterrupt: Option[Thread], reason: String): Unit

/**
* Creates a closeable resource uninterruptibly. The task is not allowed to be interrupted
* while in this state, until the resource creation finishes. E.g.,
* {{{
* val linesReader = TaskContext.get().createResourceUninterruptibly {
* new HadoopFileLinesReader(file, parser.options.lineSeparatorInRead, conf)
* }
* }}}
*/
private[spark] def createResourceUninterruptibly[T <: Closeable](resourceBuilder: => T): T
}
43 changes: 43 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,6 +17,7 @@

package org.apache.spark

import java.io.Closeable
import java.util.{Properties, Stack}
import javax.annotation.concurrent.GuardedBy

@@ -82,6 +83,13 @@ private[spark] class TaskContextImpl(
// If defined, the corresponding task has been killed and this option contains the reason.
@volatile private var reasonIfKilled: Option[String] = None

// The pending interruption request, which is blocked by uninterruptible resource creation.
// Should be protected by `TaskContext.synchronized`.
private var pendingInterruptRequest: Option[(Option[Thread], String)] = None

// Whether this task is able to be interrupted. Should be protected by `TaskContext.synchronized`.
private var _interruptible = true

// Whether the task has completed.
private var completed: Boolean = false

@@ -296,4 +304,39 @@
private[spark] override def fetchFailed: Option[FetchFailedException] = _fetchFailedException

private[spark] override def getLocalProperties: Properties = localProperties


override def interruptible(): Boolean = TaskContext.synchronized(_interruptible)

override def pendingInterrupt(threadToInterrupt: Option[Thread], reason: String): Unit = {
TaskContext.synchronized {
pendingInterruptRequest = Some((threadToInterrupt, reason))
}
}

def createResourceUninterruptibly[T <: Closeable](resourceBuilder: => T): T = {

@inline def interruptIfRequired(): Unit = {
pendingInterruptRequest.foreach { case (threadToInterrupt, reason) =>
markInterrupted(reason)
threadToInterrupt.foreach(_.interrupt())
}
killTaskIfInterrupted()
}

TaskContext.synchronized {
interruptIfRequired()
_interruptible = false
}
try {
val resource = resourceBuilder
addTaskCompletionListener[Unit](_ => resource.close())
resource
} finally {
TaskContext.synchronized {
_interruptible = true
interruptIfRequired()
}
}
}
}
20 changes: 15 additions & 5 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -22,6 +22,7 @@ import java.util.Properties

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.internal.config.APP_CALLER_CONTEXT
import org.apache.spark.internal.plugin.PluginContainer
import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
@@ -70,7 +71,7 @@ private[spark] abstract class Task[T](
val jobId: Option[Int] = None,
val appId: Option[String] = None,
val appAttemptId: Option[String] = None,
val isBarrier: Boolean = false) extends Serializable {
val isBarrier: Boolean = false) extends Serializable with Logging {

@transient lazy val metrics: TaskMetrics =
SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(serializedTaskMetrics))
@@ -231,10 +232,19 @@ private[spark] abstract class Task[T](
require(reason != null)
_reasonIfKilled = reason
if (context != null) {
context.markInterrupted(reason)
}
if (interruptThread && taskThread != null) {
taskThread.interrupt()
TaskContext.synchronized {
if (context.interruptible()) {
context.markInterrupted(reason)
if (interruptThread && taskThread != null) {
taskThread.interrupt()
}
} else {
logInfo(log"Task ${MDC(LogKeys.TASK_ID, context.taskAttemptId())} " +
log"is currently not interruptible. ")
val threadToInterrupt = if (interruptThread) Option(taskThread) else None
context.pendingInterrupt(threadToInterrupt, reason)
}
}
}
}
}
139 changes: 138 additions & 1 deletion core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark

import java.io.{File, FileOutputStream, InputStream, ObjectOutputStream}
import java.util.concurrent.{Semaphore, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

@@ -35,7 +36,7 @@ import org.apache.spark.executor.ExecutorExitCode
import org.apache.spark.internal.config._
import org.apache.spark.internal.config.Deploy._
import org.apache.spark.scheduler.{JobFailed, SparkListener, SparkListenerExecutorRemoved, SparkListenerJobEnd, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart}
import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.{ThreadUtils, Utils}

/**
* Test suite for cancelling running jobs. We run the cancellation tasks for single job action
@@ -712,6 +713,142 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
assert(executionOfInterruptibleCounter.get() < numElements)
}

Seq(true, false).foreach { interruptible =>

val (hint1, hint2) = if (interruptible) {
(" not", "")
} else {
("", " not")
}

val testName = s"SPARK-50768:$hint1 use TaskContext.createResourceUninterruptibly " +
s"would$hint2 cause stream leak on task interruption"

test(testName) {
import org.apache.spark.JobCancellationSuite._
withTempDir { dir =>

// `InterruptionSensitiveInputStream` is designed to easily leak the underlying
// stream when task thread interruption happens during its initialization, as
// the reference to the underlying stream is intentionally not available to
// `InterruptionSensitiveInputStream` at that point.
class InterruptionSensitiveInputStream(fileHint: String) extends InputStream {
private var underlying: InputStream = _

def initialize(): InputStream = {
val in: InputStream = new InputStream {

open()

private def dumpFile(typeName: String): Unit = {
var fileOut: FileOutputStream = null
var objOut: ObjectOutputStream = null
try {
val file = new File(dir, s"$typeName.$fileHint")
fileOut = new FileOutputStream(file)
objOut = new ObjectOutputStream(fileOut)
objOut.writeBoolean(true)
objOut.flush()
} finally {
if (fileOut != null) {
fileOut.close()
}
if (objOut != null) {
objOut.close()
}
}

}

private def open(): Unit = {
dumpFile("open")
}

override def close(): Unit = {
dumpFile("close")
}

override def read(): Int = -1
}

// Leave some time for the task to be interrupted during the
// creation of `InterruptionSensitiveInputStream`.
Thread.sleep(10000)

underlying = in
underlying
}

override def read(): Int = -1

override def close(): Unit = {
if (underlying != null) {
underlying.close()
}
}
}

def createStream(fileHint: String): Unit = {
if (interruptible) {
Utils.tryInitializeResource {
new InterruptionSensitiveInputStream(fileHint)
} {
_.initialize()
}
} else {
TaskContext.get().createResourceUninterruptibly[java.io.InputStream] {
Utils.tryInitializeResource {
new InterruptionSensitiveInputStream(fileHint)
} {
_.initialize()
}
}
}
}

sc = new SparkContext("local[2]", "test interrupt streams")

sc.addSparkListener(new SparkListener {
override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
// Sleep some time to ensure task has started
Thread.sleep(2000)
taskStartedSemaphore.release()
}

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
if (taskEnd.reason.isInstanceOf[TaskKilled]) {
taskCancelledSemaphore.release()
}
}
})

sc.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "true")

val fileHint = if (interruptible) "interruptible" else "uninterruptible"
val future = sc.parallelize(1 to 100, 1).mapPartitions { _ =>
createStream(fileHint)
Iterator.single(1)
}.collectAsync()

taskStartedSemaphore.acquire()
future.cancel()
taskCancelledSemaphore.acquire()

val fileOpen = new File(dir, s"open.$fileHint")
val fileClose = new File(dir, s"close.$fileHint")
assert(fileOpen.exists())

if (interruptible) {
// The underlying stream leaks when the stream creation is interruptible.
assert(!fileClose.exists())
} else {
// The underlying stream won't leak when the stream creation is uninterruptible.
assert(fileClose.exists())
}
}
}
}

def testCount(): Unit = {
// Cancel before launching any tasks
{
5 changes: 5 additions & 0 deletions project/MimaExcludes.scala
@@ -201,6 +201,11 @@ object MimaExcludes {

// SPARK-50112: Moving avro files from connector to sql/core
ProblemFilters.exclude[Problem]("org.apache.spark.sql.avro.*"),

// SPARK-50768: Introduce TaskContext.createResourceUninterruptibly to avoid stream leak by task interruption
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.interruptible"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.pendingInterrupt"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.createResourceUninterruptibly"),
) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++
loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") ++
loggingExcludes("org.apache.spark.sql.SparkSession#Builder")
