From 7fe6aa6b24304f5160b228fde92aa5a4102ad142 Mon Sep 17 00:00:00 2001 From: Shashank Pedamallu Date: Thu, 3 Jun 2021 10:33:40 -0700 Subject: [PATCH] Cherry pick fix for SPARK-27991 (#37) * Cherry pick fix for SPARK-27991 * Added missing function answerFetchBlocks * Modifying constructor calls of ShuffleBlockFetcherIterator to accommodate new argument * More changes to fix the cherrypick --- .../apache/spark/network/util/NettyUtils.java | 4 + .../CoarseGrainedExecutorBackend.scala | 8 + .../spark/internal/config/package.scala | 9 + .../shuffle/BlockStoreShuffleReader.scala | 1 + .../storage/ShuffleBlockFetcherIterator.scala | 153 +++++++++++-- .../ShuffleBlockFetcherIteratorSuite.scala | 203 +++++++++++++++++- 6 files changed, 359 insertions(+), 19 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java index 423cc0c70ea02..58fded6f984c0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -52,6 +52,10 @@ public class NettyUtils { private static final PooledByteBufAllocator[] _sharedPooledByteBufAllocator = new PooledByteBufAllocator[2]; + public static long freeDirectMemory() { + return PlatformDependent.maxDirectMemory() - PlatformDependent.usedDirectMemory(); + } + /** Creates a new ThreadFactory which prefixes each thread with the given name. 
*/ public static ThreadFactory createThreadFactory(String threadPoolPrefix) { return new DefaultThreadFactory(threadPoolPrefix, true); diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 6a1fd57873c3a..639aed4cdd57c 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -26,6 +26,7 @@ import scala.collection.mutable import scala.util.{Failure, Success} import scala.util.control.NonFatal +import io.netty.util.internal.PlatformDependent import org.json4s.DefaultFormats import org.apache.spark._ @@ -89,6 +90,13 @@ private[spark] class CoarseGrainedExecutorBackend( logInfo("Connecting to driver: " + driverUrl) try { + if (PlatformDependent.directBufferPreferred() && + PlatformDependent.maxDirectMemory() < env.conf.get(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)) { + throw new SparkException(s"Netty direct memory should at least be bigger than " + + s"'${MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM.key}', but got " + + s"${PlatformDependent.maxDirectMemory()} bytes < " + + s"${env.conf.get(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)}") + } _resources = parseOrFindResources(resourcesFileOpt) } catch { case NonFatal(e) => diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 3daa9f5362d9d..7ff7e47d39289 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1139,6 +1139,15 @@ package object config { .intConf .createWithDefault(3) +private[spark] val SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM = + ConfigBuilder("spark.shuffle.maxAttemptsOnNettyOOM") + .doc("The max attempts of a shuffle block would retry on Netty OOM issue before throwing " + + "the 
shuffle fetch failure.") + .version("3.2.0") + .internal() + .intConf + .createWithDefault(10) + private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS = ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress") .doc("This configuration limits the number of remote blocks being fetched per reduce task " + diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 30c752960d5da..de86efbb52271 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -79,6 +79,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), readMetrics, diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index fa4e46590aa5e..0c50dcf006cb2 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -20,19 +20,21 @@ package org.apache.spark.storage import java.io.{InputStream, IOException} import java.nio.channels.ClosedByInterruptException import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashMap, Queue} import scala.util.{Failure, Success} +import 
io.netty.util.internal.OutOfDirectMemoryError import org.apache.commons.io.IOUtils import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ -import org.apache.spark.network.util.TransportConf +import org.apache.spark.network.util.{NettyUtils, TransportConf} import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} @@ -61,6 +63,8 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point * for a given remote host:port. * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. + * @param maxAttemptsOnNettyOOM The max number of a block could retry due to Netty OOM before + * throwing the fetch failure. * @param detectCorrupt whether to detect any corruption in fetched blocks. * @param shuffleMetrics used to report shuffle metrics. * @param doBatchFetch fetch continuous shuffle blocks from same executor in batch if the server @@ -76,7 +80,8 @@ final class ShuffleBlockFetcherIterator( maxBytesInFlight: Long, maxReqsInFlight: Int, maxBlocksInFlightPerAddress: Int, - maxReqSizeShuffleToMem: Long, + val maxReqSizeShuffleToMem: Long, + maxAttemptsOnNettyOOM: Int, detectCorrupt: Boolean, detectCorruptUseExtraMemory: Boolean, shuffleMetrics: ShuffleReadMetricsReporter, @@ -146,6 +151,12 @@ final class ShuffleBlockFetcherIterator( /** Current number of blocks in flight per host:port */ private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + /** + * Count the retry times for the blocks due to Netty OOM. The block will stop retry if + * retry times has exceeded the [[maxAttemptsOnNettyOOM]]. 
+ */ + private[this] val blockOOMRetryCounts = new HashMap[String, Int] + /** * The blocks that can't be decompressed successfully, it is used to guarantee that we retry * at most once for those corrupted blocks. @@ -245,9 +256,21 @@ final class ShuffleBlockFetcherIterator( case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex)) }.toMap val remainingBlocks = new HashSet[String]() ++= infoMap.keys + val deferredBlocks = new ArrayBuffer[String]() val blockIds = req.blocks.map(_.blockId.toString) val address = req.address + @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = { + if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) { + val blocks = deferredBlocks.map { blockId => + val (size, mapIndex) = infoMap(blockId) + FetchBlockInfo(BlockId(blockId), size, mapIndex) + } + results.put(DeferFetchRequestResult(FetchRequest(address, blocks.toSeq))) + deferredBlocks.clear() + } + } + val blockFetchingListener = new BlockFetchingListener { override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { // Only add the buffer to results queue if the iterator is not zombie, @@ -258,17 +281,57 @@ final class ShuffleBlockFetcherIterator( // This needs to be released after use. 
buf.retain() remainingBlocks -= blockId + blockOOMRetryCounts.remove(blockId) results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2, address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty)) logDebug("remainingBlocks: " + remainingBlocks) + enqueueDeferredFetchRequestIfNecessary() } } logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}") } override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { - logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e)) + ShuffleBlockFetcherIterator.this.synchronized { + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + e match { + // SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among + // tasks) to true as early as possible. The pending fetch requests won't be sent + // afterwards until the flag is set to false on: + // 1) the Netty free memory >= maxReqSizeShuffleToMem + // - we'll check this whenever there's a fetch request succeeds. + // 2) the number of in-flight requests becomes 0 + // - we'll check this in `fetchUpToMaxBytes` whenever it's invoked. + // Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag + // only takes effect for the shuffle due to the implementation simplicity concern. + // And we'll buffer the consecutive block failures caused by the OOM error until there's + // no remaining blocks in the current request. Then, we'll package these blocks into + // a same fetch request for the retry later. In this way, instead of creating the fetch + // request per block, it would help reduce the concurrent connections and data loads + // pressure at remote server. + // Note that catching OOM and do something based on it is only a workaround for + // handling the Netty OOM issue, which is not the best way towards memory management. 
+ // We can get rid of it when we find a way to manage Netty's memory precisely. + case _: OutOfDirectMemoryError + if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM => + if (!isZombie) { + val failureTimes = blockOOMRetryCounts(blockId) + blockOOMRetryCounts(blockId) += 1 + if (isNettyOOMOnShuffle.compareAndSet(false, true)) { + // The fetcher can fail remaining blocks in batch for the same error. So we only + // log the warning once to avoid flooding the logs. + logInfo(s"Block $blockId has failed $failureTimes times " + + s"due to Netty OOM, will retry") + } + remainingBlocks -= blockId + deferredBlocks += blockId + enqueueDeferredFetchRequestIfNecessary() + } + + case _ => + results.put(FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e)) + } + } } } @@ -611,6 +674,7 @@ final class ShuffleBlockFetcherIterator( } if (isNetworkReqDone) { reqsInFlight -= 1 + resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem) logDebug("Number of requests in flight " + reqsInFlight) } @@ -682,7 +746,25 @@ final class ShuffleBlockFetcherIterator( } case FailureFetchResult(blockId, mapIndex, address, e) => - throwFetchFailedException(blockId, mapIndex, address, e) + var errorMsg: String = null + if (e.isInstanceOf[OutOfDirectMemoryError]) { + errorMsg = s"Block $blockId fetch failed after $maxAttemptsOnNettyOOM " + + s"retries due to Netty OOM" + logError(errorMsg) + } + throwFetchFailedException(blockId, mapIndex, address, e, Option(errorMsg)) + + case DeferFetchRequestResult(request) => + val address = request.address + numBlocksInFlightPerAddress(address) = + numBlocksInFlightPerAddress(address) - request.blocks.size + bytesInFlight -= request.size + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + val defReqQueue = + deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + result = null } // Send fetch requests up to maxBytesInFlight @@ -697,7 +779,8 @@ final 
class ShuffleBlockFetcherIterator( currentResult.blockId, currentResult.mapIndex, currentResult.address, - detectCorrupt && streamCompressedOrEncrypted)) + detectCorrupt && streamCompressedOrEncrypted, + currentResult.isNetworkReqDone)) } def toCompletionIterator: Iterator[(BlockId, InputStream)] = { @@ -706,6 +789,15 @@ final class ShuffleBlockFetcherIterator( } private def fetchUpToMaxBytes(): Unit = { + if (isNettyOOMOnShuffle.get()) { + if (reqsInFlight > 0) { + // Return immediately if Netty is still OOMed and there're ongoing fetch requests + return + } else { + resetNettyOOMFlagIfPossible(0) + } + } + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host // immediately, defer the request until the next time it can be processed. @@ -764,12 +856,14 @@ final class ShuffleBlockFetcherIterator( blockId: BlockId, mapIndex: Int, address: BlockManagerId, - e: Throwable) = { + e: Throwable, + message: Option[String] = None) = { + val msg = message.getOrElse(e.getMessage) blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => - throw new FetchFailedException(address, shufId, mapId, mapIndex, reduceId, e) + throw new FetchFailedException(address, shufId, mapId, mapIndex, reduceId, msg, e) case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) => - throw new FetchFailedException(address, shuffleId, mapId, mapIndex, startReduceId, e) + throw new FetchFailedException(address, shuffleId, mapId, mapIndex, startReduceId, msg, e) case _ => throw new SparkException( "Failed to get block " + blockId + ", which is not a shuffle block", e) @@ -788,7 +882,8 @@ private class BufferReleasingInputStream( private val blockId: BlockId, private val mapIndex: Int, private val address: BlockManagerId, - private val detectCorruption: Boolean) + private val detectCorruption: Boolean, + private val isNetworkReqDone: Boolean) extends InputStream { private[this] var closed = false @@ -804,9 +899,16 @@ private class BufferReleasingInputStream( 
override def close(): Unit = { if (!closed) { - delegate.close() - iterator.releaseCurrentResultBuffer() - closed = true + try { + delegate.close() + iterator.releaseCurrentResultBuffer() + } finally { + // Unset the flag when a remote request finished and free memory is fairly enough. + if (isNetworkReqDone) { + ShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible(iterator.maxReqSizeShuffleToMem) + } + closed = true + } } } @@ -873,6 +975,20 @@ private class ShuffleFetchCompletionListener(var data: ShuffleBlockFetcherIterat private[storage] object ShuffleBlockFetcherIterator { + /** + * A flag which indicates whether the Netty OOM error has raised during shuffle. + * If true, unless there's no in-flight fetch requests, all the pending shuffle + * fetch requests will be deferred until the flag is unset (whenever there's a + * complete fetch request). + */ + val isNettyOOMOnShuffle = new AtomicBoolean(false) + + def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = { + if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) { + isNettyOOMOnShuffle.compareAndSet(true, false) + } + } + /** * This function is used to merged blocks when doBatchFetch is true. Blocks which have the * same `mapId` can be merged into one block batch. The block batch is specified by a range @@ -977,10 +1093,7 @@ object ShuffleBlockFetcherIterator { /** * Result of a fetch from a remote block. */ - private[storage] sealed trait FetchResult { - val blockId: BlockId - val address: BlockManagerId - } + private[storage] sealed trait FetchResult /** * Result of a fetch from a remote block successfully. 
@@ -1016,4 +1129,10 @@ object ShuffleBlockFetcherIterator { address: BlockManagerId, e: Throwable) extends FetchResult + + /** + * Result of a fetch request that should be deferred for some reasons, e.g., Netty OOM + */ + private[storage] + case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 99c43b12d6553..fcc21da1f1cf3 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -25,9 +25,11 @@ import java.util.concurrent.{CompletableFuture, Semaphore} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future +import io.netty.util.internal.OutOfDirectMemoryError import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{mock, times, verify, when} import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} @@ -35,13 +37,21 @@ import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient} import org.apache.spark.network.util.LimitedInputStream -import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo import org.apache.spark.util.Utils class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { + private var transfer: BlockTransferService = _ + + private def answerFetchBlocks(answer: Answer[Unit]): Unit = + 
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())).thenAnswer(answer) + + private def verifyFetchBlocksInvocationCount(expectedCount: Int): Unit = + verify(transfer, times(expectedCount)).fetchBlocks(any(), any(), any(), any(), any(), any()) + private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) // Some of the tests are quite tricky because we are testing the cleanup behavior @@ -66,6 +76,31 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer } + /** Configures `transfer` (mock [[BlockTransferService]]) which mimics the Netty OOM issue. */ + private def configureNettyOOMMockTransfer( + data: Map[BlockId, ManagedBuffer], + oomBlockIndex: Int, + throwOnce: Boolean): Unit = { + var hasThrowOOM = false + answerFetchBlocks { invocation => + val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] + val listener = invocation.getArgument[BlockFetchingListener](4) + for ((blockId, i) <- blocks.zipWithIndex) { + if (!hasThrowOOM && i == oomBlockIndex) { + hasThrowOOM = throwOnce + val ctor = classOf[OutOfDirectMemoryError] + .getDeclaredConstructor(classOf[java.lang.String]) + ctor.setAccessible(true) + listener.onBlockFetchFailure(blockId, ctor.newInstance("failed to allocate memory")) + } else if (data.contains(BlockId(blockId))) { + listener.onBlockFetchSuccess(blockId, data(BlockId(blockId))) + } else { + listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId)) + } + } + } + } + private def createMockBlockManager(): BlockManager = { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-local-host", 1) @@ -123,6 +158,51 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() } + // scalastyle:off argcount + private def createShuffleBlockIteratorWithDefaults( + blocksByAddress: Map[BlockManagerId, 
Seq[(BlockId, Long, Int)]], + taskContext: Option[TaskContext] = None, + streamWrapperLimitSize: Option[Long] = None, + blockManager: Option[BlockManager] = None, + maxBytesInFlight: Long = Long.MaxValue, + maxReqsInFlight: Int = Int.MaxValue, + maxBlocksInFlightPerAddress: Int = Int.MaxValue, + maxReqSizeShuffleToMem: Int = Int.MaxValue, + maxAttemptsOnNettyOOM: Int = 10, + detectCorrupt: Boolean = true, + detectCorruptUseExtraMemory: Boolean = true, + shuffleMetrics: Option[ShuffleReadMetricsReporter] = None, + doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = { + val tContext = taskContext.getOrElse(TaskContext.empty()) + new ShuffleBlockFetcherIterator( + tContext, + transfer, + blockManager.getOrElse(createMockBlockManager()), + blocksByAddress.toIterator, + (_, in) => streamWrapperLimitSize.map(new LimitedInputStream(in, _)).getOrElse(in), + maxBytesInFlight, + maxReqsInFlight, + maxBlocksInFlightPerAddress, + maxReqSizeShuffleToMem, + maxAttemptsOnNettyOOM, + detectCorrupt, + detectCorruptUseExtraMemory, + shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()), + doBatchFetch) + } + + // scalastyle:on argcount + /** + * Convert a list of block IDs into a list of blocks with metadata, assuming all blocks have the + * same size and index. 
+ */ + private def toBlockList( + blockIds: Traversable[BlockId], + blockSize: Long, + blockMapIndex: Int): Seq[(BlockId, Long, Int)] = { + blockIds.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq + } + test("successful 3 local + 4 host local + 2 remote reads") { val blockManager = createMockBlockManager() val localBmId = blockManager.blockManagerId @@ -177,8 +257,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, 48 * 1024 * 1024, Int.MaxValue, + Int.MaxValue, Int.MaxValue, - Int.MaxValue, + 10, true, false, metrics, @@ -251,6 +332,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, + 10, true, false, metrics, @@ -282,6 +364,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, // set maxBlocksInFlightPerAddress to Int.MaxValue Int.MaxValue, + 10, true, false, metrics, @@ -324,6 +407,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, 2, // set maxBlocksInFlightPerAddress to 2 Int.MaxValue, + 10, true, false, metrics, @@ -404,6 +488,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, + 10, true, false, metrics, @@ -461,6 +546,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, + 10, true, false, metrics, @@ -515,6 +601,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, 2, Int.MaxValue, + 10, true, false, metrics, @@ -579,6 +666,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, + 10, true, false, taskContext.taskMetrics.createTempShuffleReadMetrics(), @@ -646,6 +734,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 
Int.MaxValue, Int.MaxValue, Int.MaxValue, + 10, true, false, taskContext.taskMetrics.createTempShuffleReadMetrics(), @@ -733,6 +822,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, + 10, true, true, taskContext.taskMetrics.createTempShuffleReadMetrics(), @@ -802,6 +892,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, + 10, true, true, taskContext.taskMetrics.createTempShuffleReadMetrics(), @@ -867,6 +958,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, + 10, true, true, taskContext.taskMetrics.createTempShuffleReadMetrics(), @@ -927,6 +1019,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, + 10, true, false, taskContext.taskMetrics.createTempShuffleReadMetrics(), @@ -986,6 +1079,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxReqsInFlight = Int.MaxValue, maxBlocksInFlightPerAddress = Int.MaxValue, maxReqSizeShuffleToMem = 200, + maxAttemptsOnNettyOOM = 10, detectCorrupt = true, false, taskContext.taskMetrics.createTempShuffleReadMetrics(), @@ -1032,6 +1126,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, + 10, true, false, taskContext.taskMetrics.createTempShuffleReadMetrics(), @@ -1060,4 +1155,108 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(mergedBlockId.endReduceId === bId3.reduceId + 1) assert(mergedBlock.size === inputBlocks.map(_.size).sum) } + + test("SPARK-27991: defer shuffle fetch request (one block) on Netty OOM") { + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-remote-client-1", "test-remote-host", 2) + val remoteBlocks = Map[BlockId, 
ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer()) + + configureNettyOOMMockTransfer(remoteBlocks, oomBlockIndex = 0, throwOnce = true) + + val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( + (remoteBmId, toBlockList(remoteBlocks.keys, 1, 1)) + ) + + val iterator = createShuffleBlockIteratorWithDefaults( + blocksByAddress = blocksByAddress, + // set maxBlocksInFlightPerAddress=1 so these 2 blocks + // would be grouped into 2 separate requests + maxBlocksInFlightPerAddress = 1) + + for (i <- 0 until remoteBlocks.size) { + assert(iterator.hasNext, + s"iterator should have ${remoteBlocks.size} elements but actually has $i elements") + val (blockId, inputStream) = iterator.next() + + // Make sure we release buffers when a wrapped input stream is closed. + val mockBuf = remoteBlocks(blockId) + verifyBufferRelease(mockBuf, inputStream) + } + + // 1st fetch request (contains 1 block) would fail due to Netty OOM + // 2nd fetch request retry the block of the 1st fetch request + // 3rd fetch request is a normal fetch + verifyFetchBlocksInvocationCount(3) + assert(!ShuffleBlockFetcherIterator.isNettyOOMOnShuffle.get()) + } + + Seq(0, 1, 2).foreach { oomBlockIndex => + test(s"SPARK-27991: defer shuffle fetch request (multiple blocks) on Netty OOM, " + + s"oomBlockIndex=$oomBlockIndex") { + val blockManager = createMockBlockManager() + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-remote-client-1", "test-remote-host", 2) + val remoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) + + configureNettyOOMMockTransfer(remoteBlocks, oomBlockIndex = oomBlockIndex, throwOnce = true) + + val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( + (remoteBmId, toBlockList(remoteBlocks.keys, 
1L, 1)) + ) + + val iterator = createShuffleBlockIteratorWithDefaults( + blocksByAddress = blocksByAddress, + // set maxBlocksInFlightPerAddress=3, so these 3 blocks would be grouped into 1 request + maxBlocksInFlightPerAddress = 3) + + for (i <- 0 until remoteBlocks.size) { + assert(iterator.hasNext, + s"iterator should have ${remoteBlocks.size} elements but actually has $i elements") + val (blockId, inputStream) = iterator.next() + + // Make sure we release buffers when a wrapped input stream is closed. + val mockBuf = remoteBlocks(blockId) + verifyBufferRelease(mockBuf, inputStream) + } + + // 1st fetch request (contains 3 blocks) would fail on the someone block due to Netty OOM + // but succeed for the remaining blocks + // 2nd fetch request retry the failed block of the 1st fetch + verifyFetchBlocksInvocationCount(2) + assert(!ShuffleBlockFetcherIterator.isNettyOOMOnShuffle.get()) + } + } + + test("SPARK-27991: block shouldn't retry endlessly on Netty OOM") { + val blockManager = createMockBlockManager() + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-remote-client-1", "test-remote-host", 2) + val remoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer()) + + configureNettyOOMMockTransfer(remoteBlocks, oomBlockIndex = 0, throwOnce = false) + + val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( + (remoteBmId, toBlockList(remoteBlocks.keys, 1L, 1)) + ) + + val iterator = createShuffleBlockIteratorWithDefaults( + blocksByAddress = blocksByAddress, + // set maxBlocksInFlightPerAddress=1 so these 2 blocks + // would be grouped into 2 separate requests + maxBlocksInFlightPerAddress = 1) + + val e = intercept[FetchFailedException] { + iterator.next() + } + assert(e.getMessage.contains("fetch failed after 10 retries due to Netty OOM")) + } }