Cherry pick fix for SPARK-27991 (#37)
* Cherry pick fix for SPARK-27991

* Added missing function answerFetchBlocks

* Modifying constructor calls of ShuffleBlockFetcherIterator to accommodate the new argument

* More changes to fix the cherry-pick
s-pedamallu authored Jun 3, 2021
1 parent 0ae05fa commit 7fe6aa6
Showing 6 changed files with 359 additions and 19 deletions.
@@ -52,6 +52,10 @@ public class NettyUtils {
private static final PooledByteBufAllocator[] _sharedPooledByteBufAllocator =
new PooledByteBufAllocator[2];

public static long freeDirectMemory() {
return PlatformDependent.maxDirectMemory() - PlatformDependent.usedDirectMemory();
}

/** Creates a new ThreadFactory which prefixes each thread with the given name. */
public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
return new DefaultThreadFactory(threadPoolPrefix, true);
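For context, the helper added above reports how much of Netty's direct-memory budget is still unused; later in this commit the shuffle fetcher compares it against a lower bound before resuming fetches. A minimal sketch of that comparison (the object name and threshold are illustrative, not part of this patch):

import org.apache.spark.network.util.NettyUtils

object DirectMemoryHeadroomSketch {
  // Illustrative gate: resume work only when Netty's unused direct memory
  // is at least lowerBoundBytes (the patch uses maxReqSizeShuffleToMem).
  def hasHeadroom(lowerBoundBytes: Long): Boolean =
    NettyUtils.freeDirectMemory() >= lowerBoundBytes
}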
@@ -26,6 +26,7 @@ import scala.collection.mutable
import scala.util.{Failure, Success}
import scala.util.control.NonFatal

import io.netty.util.internal.PlatformDependent
import org.json4s.DefaultFormats

import org.apache.spark._
@@ -89,6 +90,13 @@ private[spark] class CoarseGrainedExecutorBackend(

logInfo("Connecting to driver: " + driverUrl)
try {
if (PlatformDependent.directBufferPreferred() &&
PlatformDependent.maxDirectMemory() < env.conf.get(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)) {
throw new SparkException(s"Netty direct memory should be at least as large as " +
s"'${MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM.key}', but got " +
s"${PlatformDependent.maxDirectMemory()} bytes < " +
s"${env.conf.get(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)}")
}
_resources = parseOrFindResources(resourcesFileOpt)
} catch {
case NonFatal(e) =>
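The guard above fails executor startup fast when the JVM's direct-memory ceiling cannot cover even a single fetch-to-memory request. A hedged sketch of a configuration that satisfies the check, with illustrative values (not recommendations) and assuming the Spark 3.x key spark.network.maxRemoteBlockSizeFetchToMem:

import org.apache.spark.SparkConf

object FetchToMemSizingSketch {
  // Illustrative sizing: keep the fetch-to-memory threshold below the JVM's
  // direct-memory ceiling so the startup check in CoarseGrainedExecutorBackend
  // passes. 200m < 1g here, so the check would succeed.
  val conf = new SparkConf()
    .set("spark.network.maxRemoteBlockSizeFetchToMem", "200m")
    .set("spark.executor.extraJavaOptions", "-XX:MaxDirectMemorySize=1g")
}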
@@ -1139,6 +1139,15 @@ package object config {
.intConf
.createWithDefault(3)

private[spark] val SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM =
ConfigBuilder("spark.shuffle.maxAttemptsOnNettyOOM")
.doc("The max attempts of a shuffle block would retry on Netty OOM issue before throwing " +
"the shuffle fetch failure.")
.version("3.2.0")
.internal()
.intConf
.createWithDefault(10)

private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS =
ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress")
.doc("This configuration limits the number of remote blocks being fetched per reduce task " +
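The new knob is marked internal, and the default of 10 should rarely need tuning; still, a minimal sketch of overriding it (object name illustrative):

import org.apache.spark.SparkConf

object NettyOomRetryBudgetSketch {
  // Sketch: raise the per-block retry budget on Netty OOM before the fetch
  // is surfaced as a shuffle fetch failure.
  val conf = new SparkConf().set("spark.shuffle.maxAttemptsOnNettyOOM", "20")
}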
@@ -79,6 +79,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
readMetrics,
@@ -20,19 +20,21 @@ package org.apache.spark.storage
import java.io.{InputStream, IOException}
import java.nio.channels.ClosedByInterruptException
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashMap, Queue}
import scala.util.{Failure, Success}

import io.netty.util.internal.OutOfDirectMemoryError
import org.apache.commons.io.IOUtils

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle._
import org.apache.spark.network.util.TransportConf
import org.apache.spark.network.util.{NettyUtils, TransportConf}
import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}

@@ -61,6 +63,8 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
* @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point
* for a given remote host:port.
* @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
* @param maxAttemptsOnNettyOOM The max number of times a block could retry due to Netty OOM before
* throwing the fetch failure.
* @param detectCorrupt whether to detect any corruption in fetched blocks.
* @param shuffleMetrics used to report shuffle metrics.
* @param doBatchFetch fetch continuous shuffle blocks from same executor in batch if the server
@@ -76,7 +80,8 @@ final class ShuffleBlockFetcherIterator(
maxBytesInFlight: Long,
maxReqsInFlight: Int,
maxBlocksInFlightPerAddress: Int,
maxReqSizeShuffleToMem: Long,
val maxReqSizeShuffleToMem: Long,
maxAttemptsOnNettyOOM: Int,
detectCorrupt: Boolean,
detectCorruptUseExtraMemory: Boolean,
shuffleMetrics: ShuffleReadMetricsReporter,
@@ -146,6 +151,12 @@ final class ShuffleBlockFetcherIterator(
/** Current number of blocks in flight per host:port */
private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]()

/**
* Counts the retry times for blocks that failed due to Netty OOM. A block stops retrying
* once its retry count exceeds [[maxAttemptsOnNettyOOM]].
*/
private[this] val blockOOMRetryCounts = new HashMap[String, Int]

/**
* The blocks that can't be decompressed successfully. This is used to guarantee that we
* retry at most once for those corrupted blocks.
@@ -245,9 +256,21 @@ final class ShuffleBlockFetcherIterator(
case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex))
}.toMap
val remainingBlocks = new HashSet[String]() ++= infoMap.keys
val deferredBlocks = new ArrayBuffer[String]()
val blockIds = req.blocks.map(_.blockId.toString)
val address = req.address

@inline def enqueueDeferredFetchRequestIfNecessary(): Unit = {
if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) {
val blocks = deferredBlocks.map { blockId =>
val (size, mapIndex) = infoMap(blockId)
FetchBlockInfo(BlockId(blockId), size, mapIndex)
}
results.put(DeferFetchRequestResult(FetchRequest(address, blocks.toSeq)))
deferredBlocks.clear()
}
}

val blockFetchingListener = new BlockFetchingListener {
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
// Only add the buffer to results queue if the iterator is not zombie,
@@ -258,17 +281,57 @@
// This needs to be released after use.
buf.retain()
remainingBlocks -= blockId
blockOOMRetryCounts.remove(blockId)
results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2,
address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty))
logDebug("remainingBlocks: " + remainingBlocks)
enqueueDeferredFetchRequestIfNecessary()
}
}
logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}")
}

override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
results.put(new FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
ShuffleBlockFetcherIterator.this.synchronized {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
e match {
// SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among
// tasks) to true as early as possible. Pending fetch requests won't be sent
// afterwards until the flag is set back to false, which happens when:
// 1) Netty's free memory >= maxReqSizeShuffleToMem
//    - we check this whenever a fetch request succeeds;
// 2) the number of in-flight requests becomes 0
//    - we check this in `fetchUpToMaxBytes` whenever it's invoked.
// Although Netty memory is shared across multiple modules, e.g., shuffle and rpc, the
// flag only takes effect for shuffle, for implementation simplicity.
// We buffer the consecutive block failures caused by the OOM error until there are no
// remaining blocks in the current request, then package those blocks into a single
// fetch request to retry later. Compared to creating one fetch request per block,
// this reduces the number of concurrent connections and the data load pressure on
// the remote server.
// Note that catching the OOM and reacting to it is only a workaround for the Netty
// OOM issue, not a proper approach to memory management. We can get rid of it once
// we find a way to manage Netty's memory precisely.
case _: OutOfDirectMemoryError
if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM =>
if (!isZombie) {
val failureTimes = blockOOMRetryCounts(blockId)
blockOOMRetryCounts(blockId) += 1
if (isNettyOOMOnShuffle.compareAndSet(false, true)) {
// The fetcher can fail the remaining blocks in batch for the same error, so we
// only log once here to avoid flooding the logs.
logInfo(s"Block $blockId has failed $failureTimes times " +
s"due to Netty OOM, will retry")
}
remainingBlocks -= blockId
deferredBlocks += blockId
enqueueDeferredFetchRequestIfNecessary()
}

case _ =>
results.put(FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
}
}
}
}

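The listener above implements the deferral protocol described in the SPARK-27991 comment. A stripped-down sketch of the same pattern, with simplified, illustrative names rather than the production types:

import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable.{ArrayBuffer, HashSet}

object NettyOomDeferralSketch {
  // Shared flag: once any task sees a Netty OOM, pending fetches pause.
  val nettyOomOnShuffle = new AtomicBoolean(false)

  val remainingBlocks = HashSet.empty[String]    // blocks still outstanding in this request
  val deferredBlocks = ArrayBuffer.empty[String] // OOM-failed blocks buffered for one retry

  def onBlockFailedWithNettyOom(blockId: String)(requeue: Seq[String] => Unit): Unit = {
    nettyOomOnShuffle.compareAndSet(false, true) // pause new fetches as early as possible
    remainingBlocks -= blockId
    deferredBlocks += blockId
    // Once the current request has no blocks left, retry all deferred blocks as a
    // single fetch request instead of one request per block, which reduces the
    // connection and load pressure on the remote server.
    if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) {
      requeue(deferredBlocks.toSeq)
      deferredBlocks.clear()
    }
  }
}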
@@ -611,6 +674,7 @@ final class ShuffleBlockFetcherIterator(
}
if (isNetworkReqDone) {
reqsInFlight -= 1
resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem)
logDebug("Number of requests in flight " + reqsInFlight)
}

@@ -682,7 +746,25 @@ final class ShuffleBlockFetcherIterator(
}

case FailureFetchResult(blockId, mapIndex, address, e) =>
throwFetchFailedException(blockId, mapIndex, address, e)
var errorMsg: String = null
if (e.isInstanceOf[OutOfDirectMemoryError]) {
errorMsg = s"Block $blockId fetch failed after $maxAttemptsOnNettyOOM " +
s"retries due to Netty OOM"
logError(errorMsg)
}
throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg))

case DeferFetchRequestResult(request) =>
val address = request.address
numBlocksInFlightPerAddress(address) =
numBlocksInFlightPerAddress(address) - request.blocks.size
bytesInFlight -= request.size
reqsInFlight -= 1
logDebug("Number of requests in flight " + reqsInFlight)
val defReqQueue =
deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
defReqQueue.enqueue(request)
result = null
}

// Send fetch requests up to maxBytesInFlight
@@ -697,7 +779,8 @@
currentResult.blockId,
currentResult.mapIndex,
currentResult.address,
detectCorrupt && streamCompressedOrEncrypted))
detectCorrupt && streamCompressedOrEncrypted,
currentResult.isNetworkReqDone))
}

def toCompletionIterator: Iterator[(BlockId, InputStream)] = {
@@ -706,6 +789,15 @@
}

private def fetchUpToMaxBytes(): Unit = {
if (isNettyOOMOnShuffle.get()) {
if (reqsInFlight > 0) {
// Return immediately if Netty is still OOMed and there are ongoing fetch requests
return
} else {
resetNettyOOMFlagIfPossible(0)
}
}

// Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
// immediately, defer the request until the next time it can be processed.

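One design note on the gate above: if the flag is set while requests are still in flight, the iterator simply waits, because an in-flight request will eventually complete and call resetNettyOOMFlagIfPossible with maxReqSizeShuffleToMem as the bound. If nothing is in flight, no such completion can ever arrive, so the flag is reset with a lower bound of 0, which always succeeds, guaranteeing the iterator makes progress instead of stalling indefinitely.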
@@ -764,12 +856,14 @@ final class ShuffleBlockFetcherIterator(
blockId: BlockId,
mapIndex: Int,
address: BlockManagerId,
e: Throwable) = {
e: Throwable,
message: Option[String] = None) = {
val msg = message.getOrElse(e.getMessage)
blockId match {
case ShuffleBlockId(shufId, mapId, reduceId) =>
throw new FetchFailedException(address, shufId, mapId, mapIndex, reduceId, e)
throw new FetchFailedException(address, shufId, mapId, mapIndex, reduceId, msg, e)
case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) =>
throw new FetchFailedException(address, shuffleId, mapId, mapIndex, startReduceId, e)
throw new FetchFailedException(address, shuffleId, mapId, mapIndex, startReduceId, msg, e)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block", e)
@@ -788,7 +882,8 @@ private class BufferReleasingInputStream(
private val blockId: BlockId,
private val mapIndex: Int,
private val address: BlockManagerId,
private val detectCorruption: Boolean)
private val detectCorruption: Boolean,
private val isNetworkReqDone: Boolean)
extends InputStream {
private[this] var closed = false

@@ -804,9 +899,16 @@

override def close(): Unit = {
if (!closed) {
delegate.close()
iterator.releaseCurrentResultBuffer()
closed = true
try {
delegate.close()
iterator.releaseCurrentResultBuffer()
} finally {
// Unset the flag when a remote request has finished and enough direct memory is free.
if (isNetworkReqDone) {
ShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible(iterator.maxReqSizeShuffleToMem)
}
closed = true
}
}
}

@@ -873,6 +975,20 @@ private class ShuffleFetchCompletionListener(var data: ShuffleBlockFetcherIterat
private[storage]
object ShuffleBlockFetcherIterator {

/**
* A flag which indicates whether a Netty OOM error has been raised during shuffle.
* If true, all pending shuffle fetch requests are deferred until the flag is unset
* (which happens whenever a fetch request completes), unless there are no in-flight
* fetch requests, in which case the flag is reset directly.
*/
val isNettyOOMOnShuffle = new AtomicBoolean(false)

def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = {
if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) {
isNettyOOMOnShuffle.compareAndSet(true, false)
}
}

/**
* This function is used to merge blocks when doBatchFetch is true. Blocks which have the
* same `mapId` can be merged into one block batch. The block batch is specified by a range
@@ -977,10 +1093,7 @@ object ShuffleBlockFetcherIterator {
/**
* Result of a fetch from a remote block.
*/
private[storage] sealed trait FetchResult {
val blockId: BlockId
val address: BlockManagerId
}
private[storage] sealed trait FetchResult

/**
* Result of a successful fetch of a remote block.
@@ -1016,4 +1129,10 @@
address: BlockManagerId,
e: Throwable)
extends FetchResult

/**
* Result of a fetch request that should be deferred for some reason, e.g., Netty OOM.
*/
private[storage]
case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
}
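Taken together, the flag is set on the first OutOfDirectMemoryError seen by any fetch listener, and cleared either when a network request completes with enough free direct memory (at least maxReqSizeShuffleToMem) or unconditionally once no requests remain in flight. A condensed, paraphrased sketch of that lifecycle (not the production code; freeDirectMemory is passed in for the sketch):

import java.util.concurrent.atomic.AtomicBoolean

object NettyOomFlagLifecycleSketch {
  val flag = new AtomicBoolean(false)

  // Set on the first OutOfDirectMemoryError any fetch listener observes.
  def onNettyOom(): Unit = flag.compareAndSet(false, true)

  // Mirrors resetNettyOOMFlagIfPossible: clear the flag only once enough
  // direct memory is free again (lowerBound is 0 when nothing is in flight,
  // so the reset then always succeeds).
  def resetIfPossible(freeDirectMemory: Long, lowerBound: Long): Unit = {
    if (flag.get() && freeDirectMemory >= lowerBound) {
      flag.compareAndSet(true, false)
    }
  }
}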
