From 7f1c201d92bb3f3ddc4cff24342c75946b1ad2b2 Mon Sep 17 00:00:00 2001 From: noorall <863485501@qq.com> Date: Tue, 12 Nov 2024 15:35:55 +0800 Subject: [PATCH] [FLINK-36576][runtime] Improving amount-based data balancing distribution algorithm for DefaultVertexParallelismAndInputInfosDecider --- .../VertexInputInfoComputationUtils.java | 2 +- .../AdaptiveBatchSchedulerFactory.java | 11 + .../BatchExecutionOptionsInternal.java | 49 ++ .../adaptivebatch/BisectionSearchUtils.java | 2 +- ...VertexParallelismAndInputInfosDecider.java | 495 ++++--------- .../util/AggregatedBlockingInputInfo.java | 189 +++++ .../util/AllToAllVertexInputInfoComputer.java | 486 +++++++++++++ .../PointwiseVertexInputInfoComputer.java | 185 +++++ .../adaptivebatch/util/SubpartitionSlice.java | 135 ++++ ...xParallelismAndInputInfosDeciderUtils.java | 661 ++++++++++++++++++ ...exParallelismAndInputInfosDeciderTest.java | 98 +-- .../VertexInputInfoComputerTestUtil.java | 342 +++++++++ .../util/AggregatedBlockingInputInfoTest.java | 102 +++ .../AllToAllVertexInputInfoComputerTest.java | 400 +++++++++++ .../PointwiseVertexInputInfoComputerTest.java | 172 +++++ .../util/SubpartitionSliceTest.java | 102 +++ ...allelismAndInputInfosDeciderUtilsTest.java | 214 ++++++ 17 files changed, 3204 insertions(+), 441 deletions(-) create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchExecutionOptionsInternal.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AggregatedBlockingInputInfo.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AllToAllVertexInputInfoComputer.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/PointwiseVertexInputInfoComputer.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/SubpartitionSlice.java create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtils.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/VertexInputInfoComputerTestUtil.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AggregatedBlockingInputInfoTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AllToAllVertexInputInfoComputerTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/PointwiseVertexInputInfoComputerTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/SubpartitionSliceTest.java create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtilsTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java index 3c8dfc50e9b40..4f9075c53ad0e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/VertexInputInfoComputationUtils.java @@ -104,7 +104,7 @@ public static Map computeVertexInputI * @param isDynamicGraph whether is dynamic graph * 
@return the computed {@link JobVertexInputInfo} */ - static JobVertexInputInfo computeVertexInputInfoForPointwise( + public static JobVertexInputInfo computeVertexInputInfoForPointwise( int sourceCount, int targetCount, Function numOfSubpartitionsRetriever, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java index 8c8e72059a72b..d77f18e0fa772 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java @@ -199,6 +199,17 @@ public SchedulerNG createInstance( new ScheduledExecutorServiceAdapter(futureExecutor), DefaultVertexParallelismAndInputInfosDecider.from( getDefaultMaxParallelism(jobMasterConfiguration, executionConfig), + executionPlan + .getJobConfiguration() + .get( + BatchExecutionOptionsInternal + .ADAPTIVE_SKEWED_OPTIMIZATION_SKEWED_FACTOR), + executionPlan + .getJobConfiguration() + .get( + BatchExecutionOptionsInternal + .ADAPTIVE_SKEWED_OPTIMIZATION_SKEWED_THRESHOLD) + .getBytes(), jobMasterConfiguration), jobRecoveryHandler); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchExecutionOptionsInternal.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchExecutionOptionsInternal.java new file mode 100644 index 0000000000000..78d7d706697a5 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BatchExecutionOptionsInternal.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.configuration.ConfigOption; +import org.apache.flink.configuration.MemorySize; + +import static org.apache.flink.configuration.ConfigOptions.key; + +/** Internal configuration options for the batch job execution. 
*/ +@Internal +public class BatchExecutionOptionsInternal { + public static final ConfigOption ADAPTIVE_SKEWED_OPTIMIZATION_SKEWED_THRESHOLD = + key("$internal.execution.batch.adaptive.skewed-optimization.skewed-threshold") + .memoryType() + .defaultValue(MemorySize.ofMebiBytes(256)) + .withDescription( + "Flink will automatically reduce the ratio of the maximum to median concurrent task " + + "processing data volume to below the skewed-factor and will also achieve " + + "a more balanced data distribution, unless the maximum value is below the " + + "skewed-threshold."); + + public static final ConfigOption ADAPTIVE_SKEWED_OPTIMIZATION_SKEWED_FACTOR = + key("$internal.execution.batch.adaptive.skewed-optimization.skewed-factor") + .doubleType() + .defaultValue(4.0) + .withDescription( + "When the maximum data volume processed by a concurrent task is greater than the " + + "skewed-threshold, Flink can automatically reduce the ratio of the maximum " + + "data volume processed by a concurrent task to the median to less than the " + + "skewed-factor and will also achieve a more balanced data distribution."); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BisectionSearchUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BisectionSearchUtils.java index 329e7124fa49a..f026a5681038d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BisectionSearchUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BisectionSearchUtils.java @@ -22,7 +22,7 @@ import java.util.function.Function; /** Utility class for bisection search. */ -class BisectionSearchUtils { +public class BisectionSearchUtils { public static long findMinLegalValue( Function legalChecker, long low, long high) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java index 42c3bec4d7e8f..38ac16e467dff 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDecider.java @@ -22,29 +22,25 @@ import org.apache.flink.configuration.BatchExecutionOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.MemorySize; -import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo; -import org.apache.flink.runtime.executiongraph.IndexRange; import org.apache.flink.runtime.executiongraph.JobVertexInputInfo; import org.apache.flink.runtime.executiongraph.ParallelismAndInputInfos; -import org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.scheduler.adaptivebatch.util.AllToAllVertexInputInfoComputer; +import org.apache.flink.runtime.scheduler.adaptivebatch.util.PointwiseVertexInputInfoComputer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.Set; -import 
java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.IntStream; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.calculateDataVolumePerTaskForInputsGroup; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.checkAndGetParallelism; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.getNonBroadcastInputInfos; import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; @@ -53,11 +49,12 @@ * Default implementation of {@link VertexParallelismAndInputInfosDecider}. This implementation will * decide parallelism and {@link JobVertexInputInfo}s as follows: * - *

1. For job vertices whose inputs are all ALL_TO_ALL edges, evenly distribute data to - * downstream subtasks, make different downstream subtasks consume roughly the same amount of data. + *

1. We will first attempt to evenly distribute data to downstream subtasks, so that different + downstream subtasks consume roughly the same amount of data. - *

2. For other cases, evenly distribute subpartitions to downstream subtasks, make different - * downstream subtasks consume roughly the same number of subpartitions. + *

2. If step 1 fails or is not applicable, we will proceed to evenly distribute subpartitions + to downstream subtasks, so that different downstream subtasks consume roughly the same number of + subpartitions. */ public class DefaultVertexParallelismAndInputInfosDecider implements VertexParallelismAndInputInfosDecider { @@ -65,26 +62,20 @@ public class DefaultVertexParallelismAndInputInfosDecider private static final Logger LOG = LoggerFactory.getLogger(DefaultVertexParallelismAndInputInfosDecider.class); - /** - * The maximum number of subpartitions belonging to the same result that each task can consume. - * We currently need this limitation to avoid too many channels in a downstream task leading to - * poor performance. - * - *

TODO: Once we support one channel to consume multiple upstream subpartitions in the - * future, we can remove this limitation - */ - private static final int MAX_NUM_SUBPARTITIONS_PER_TASK_CONSUME = 32768; - private final int globalMaxParallelism; private final int globalMinParallelism; private final long dataVolumePerTask; private final int globalDefaultSourceParallelism; + private final AllToAllVertexInputInfoComputer allToAllVertexInputInfoComputer; + private final PointwiseVertexInputInfoComputer pointwiseVertexInputInfoComputer; private DefaultVertexParallelismAndInputInfosDecider( int globalMaxParallelism, int globalMinParallelism, MemorySize dataVolumePerTask, - int globalDefaultSourceParallelism) { + int globalDefaultSourceParallelism, + double skewedFactor, + long skewedThreshold) { checkArgument(globalMinParallelism > 0, "The minimum parallelism must be larger than 0."); checkArgument( @@ -94,11 +85,17 @@ private DefaultVertexParallelismAndInputInfosDecider( globalDefaultSourceParallelism > 0, "The default source parallelism must be larger than 0."); checkNotNull(dataVolumePerTask); + checkArgument( + skewedFactor > 0, "The default skewed partition factor must be larger than 0."); + checkArgument(skewedThreshold > 0, "The default skewed threshold must be larger than 0."); this.globalMaxParallelism = globalMaxParallelism; this.globalMinParallelism = globalMinParallelism; this.dataVolumePerTask = dataVolumePerTask.getBytes(); this.globalDefaultSourceParallelism = globalDefaultSourceParallelism; + this.allToAllVertexInputInfoComputer = + new AllToAllVertexInputInfoComputer(skewedFactor, skewedThreshold); + this.pointwiseVertexInputInfoComputer = new PointwiseVertexInputInfoComputer(); } @Override @@ -126,52 +123,41 @@ public ParallelismAndInputInfos decideParallelismAndInputInfosForVertex( ? vertexInitialParallelism : computeSourceParallelismUpperBound(jobVertexId, vertexMaxParallelism); return new ParallelismAndInputInfos(parallelism, Collections.emptyMap()); - } else { - int minParallelism = Math.max(globalMinParallelism, vertexMinParallelism); - int maxParallelism = globalMaxParallelism; - - if (vertexInitialParallelism == ExecutionConfig.PARALLELISM_DEFAULT - && vertexMaxParallelism < minParallelism) { - LOG.info( - "The vertex maximum parallelism {} is smaller than the minimum parallelism {}. " - + "Use {} as the lower bound to decide parallelism of job vertex {}.", - vertexMaxParallelism, - minParallelism, - vertexMaxParallelism, - jobVertexId); - minParallelism = vertexMaxParallelism; - } - if (vertexInitialParallelism == ExecutionConfig.PARALLELISM_DEFAULT - && vertexMaxParallelism < maxParallelism) { - LOG.info( - "The vertex maximum parallelism {} is smaller than the global maximum parallelism {}. 
" - + "Use {} as the upper bound to decide parallelism of job vertex {}.", - vertexMaxParallelism, - maxParallelism, - vertexMaxParallelism, - jobVertexId); - maxParallelism = vertexMaxParallelism; - } - checkState(maxParallelism >= minParallelism); - - if (vertexInitialParallelism == ExecutionConfig.PARALLELISM_DEFAULT - && areAllInputsAllToAll(consumedResults) - && !areAllInputsBroadcast(consumedResults)) { - return decideParallelismAndEvenlyDistributeData( - jobVertexId, - consumedResults, - vertexInitialParallelism, - minParallelism, - maxParallelism); - } else { - return decideParallelismAndEvenlyDistributeSubpartitions( - jobVertexId, - consumedResults, - vertexInitialParallelism, - minParallelism, - maxParallelism); - } } + + int minParallelism = Math.max(globalMinParallelism, vertexMinParallelism); + int maxParallelism = globalMaxParallelism; + + if (vertexInitialParallelism == ExecutionConfig.PARALLELISM_DEFAULT + && vertexMaxParallelism < minParallelism) { + LOG.info( + "The vertex maximum parallelism {} is smaller than the minimum parallelism {}. " + + "Use {} as the lower bound to decide parallelism of job vertex {}.", + vertexMaxParallelism, + minParallelism, + vertexMaxParallelism, + jobVertexId); + minParallelism = vertexMaxParallelism; + } + if (vertexInitialParallelism == ExecutionConfig.PARALLELISM_DEFAULT + && vertexMaxParallelism < maxParallelism) { + LOG.info( + "The vertex maximum parallelism {} is smaller than the global maximum parallelism {}. " + + "Use {} as the upper bound to decide parallelism of job vertex {}.", + vertexMaxParallelism, + maxParallelism, + vertexMaxParallelism, + jobVertexId); + maxParallelism = vertexMaxParallelism; + } + checkState(maxParallelism >= minParallelism); + + return decideParallelismAndInputInfosForNonSource( + jobVertexId, + consumedResults, + vertexInitialParallelism, + minParallelism, + maxParallelism); } @Override @@ -195,43 +181,71 @@ public long getDataVolumePerTask() { return dataVolumePerTask; } - private static boolean areAllInputsAllToAll(List consumedResults) { - return consumedResults.stream().noneMatch(BlockingInputInfo::isPointwise); - } - - private static boolean areAllInputsBroadcast(List consumedResults) { - return consumedResults.stream() - .allMatch(BlockingInputInfo::isSingleSubpartitionContainsAllData); - } - - /** - * Decide parallelism and input infos, which will make the subpartitions be evenly distributed - * to downstream subtasks, such that different downstream subtasks consume roughly the same - * number of subpartitions. - * - * @param jobVertexId The job vertex id - * @param consumedResults The information of consumed blocking results - * @param initialParallelism The initial parallelism of the job vertex - * @param minParallelism the min parallelism - * @param maxParallelism the max parallelism - * @return the parallelism and vertex input infos - */ - private ParallelismAndInputInfos decideParallelismAndEvenlyDistributeSubpartitions( + private ParallelismAndInputInfos decideParallelismAndInputInfosForNonSource( JobVertexID jobVertexId, List consumedResults, - int initialParallelism, + int vertexInitialParallelism, int minParallelism, int maxParallelism) { - checkArgument(!consumedResults.isEmpty()); int parallelism = - initialParallelism > 0 - ? initialParallelism + vertexInitialParallelism > 0 + ? 
vertexInitialParallelism : decideParallelism( jobVertexId, consumedResults, minParallelism, maxParallelism); + + List pointwiseInputs = new ArrayList<>(); + + List allToAllInputs = new ArrayList<>(); + + consumedResults.forEach( + inputInfo -> { + if (inputInfo.isPointwise()) { + pointwiseInputs.add(inputInfo); + } else { + allToAllInputs.add(inputInfo); + } + }); + + // For AllToAll-like inputs, we derive parallelism as a whole, while for Pointwise inputs, + // we derive parallelism separately for each input, and our goal is to ensure that the final + // parallelisms of those inputs are consistent and meet expectations. + // Since AllToAll supports deriving parallelism within a flexible range, this might + // interfere with the target parallelism. Therefore, in the following cases, we need to + // reset the minimum and maximum parallelism to limit the flexibility of parallelism + // derivation to achieve the goal: + // 1. The vertex has a specified parallelism, and we should follow it. + // 2. There are pointwise inputs, which means that there may be inputs whose parallelism is + // derived one-by-one, so we need to reset the min and max parallelism. + if (vertexInitialParallelism > 0 || !pointwiseInputs.isEmpty()) { + minParallelism = parallelism; + maxParallelism = parallelism; + } + + Map vertexInputInfos = new HashMap<>(); + + if (!allToAllInputs.isEmpty()) { + vertexInputInfos.putAll( + allToAllVertexInputInfoComputer.compute( + jobVertexId, + allToAllInputs, + parallelism, + minParallelism, + maxParallelism, + calculateDataVolumePerTaskForInputsGroup( + dataVolumePerTask, allToAllInputs, consumedResults))); + } + + if (!pointwiseInputs.isEmpty()) { + vertexInputInfos.putAll( + pointwiseVertexInputInfoComputer.compute( + pointwiseInputs, + parallelism, + calculateDataVolumePerTaskForInputsGroup( + dataVolumePerTask, pointwiseInputs, consumedResults))); + } + return new ParallelismAndInputInfos( - parallelism, - VertexInputInfoComputationUtils.computeVertexInputInfos( - parallelism, consumedResults, true)); + checkAndGetParallelism(vertexInputInfos.values()), vertexInputInfos); } int decideParallelism( @@ -244,7 +258,7 @@ int decideParallelism( // Considering that the sizes of broadcast results are usually very small, we compute the // parallelism only based on sizes of non-broadcast results final List nonBroadcastResults = - getNonBroadcastResultInfos(consumedResults); + getNonBroadcastInputInfos(consumedResults); if (nonBroadcastResults.isEmpty()) { return minParallelism; } @@ -254,12 +268,6 @@ int decideParallelism( .mapToLong(BlockingInputInfo::getNumBytesProduced) .sum(); int parallelism = (int) Math.ceil((double) totalBytes / dataVolumePerTask); - int minParallelismLimitedByMaxSubpartitions = - (int) - Math.ceil( - (double) getMaxNumSubpartitions(nonBroadcastResults) - / MAX_NUM_SUBPARTITIONS_PER_TASK_CONSUME); - parallelism = Math.max(parallelism, minParallelismLimitedByMaxSubpartitions); LOG.debug( "The total size of non-broadcast data is {}, the initially decided parallelism of job vertex {} is {}.", @@ -290,276 +298,11 @@ int decideParallelism( return parallelism; } - /** - * Decide parallelism and input infos, which will make the data be evenly distributed to - downstream subtasks, such that different downstream subtasks consume roughly the same amount - * of data.
- * - * @param jobVertexId The job vertex id - * @param consumedResults The information of consumed blocking results - * @param initialParallelism The initial parallelism of the job vertex - * @param minParallelism the min parallelism - * @param maxParallelism the max parallelism - * @return the parallelism and vertex input infos - */ - private ParallelismAndInputInfos decideParallelismAndEvenlyDistributeData( - JobVertexID jobVertexId, - List consumedResults, - int initialParallelism, - int minParallelism, - int maxParallelism) { - checkArgument(initialParallelism == ExecutionConfig.PARALLELISM_DEFAULT); - checkArgument(!consumedResults.isEmpty()); - consumedResults.forEach(resultInfo -> checkState(!resultInfo.isPointwise())); - - // Considering that the sizes of broadcast results are usually very small, we compute the - // parallelism and input infos only based on sizes of non-broadcast results - final List nonBroadcastResults = - getNonBroadcastResultInfos(consumedResults); - int subpartitionNum = checkAndGetSubpartitionNum(nonBroadcastResults); - - long[] bytesBySubpartition = new long[subpartitionNum]; - Arrays.fill(bytesBySubpartition, 0L); - for (BlockingInputInfo resultInfo : nonBroadcastResults) { - List subpartitionBytes = resultInfo.getAggregatedSubpartitionBytes(); - for (int i = 0; i < subpartitionNum; ++i) { - bytesBySubpartition[i] += subpartitionBytes.get(i); - } - } - - int maxNumPartitions = getMaxNumPartitions(nonBroadcastResults); - int maxRangeSize = MAX_NUM_SUBPARTITIONS_PER_TASK_CONSUME / maxNumPartitions; - // compute subpartition ranges - List subpartitionRanges = - computeSubpartitionRanges(bytesBySubpartition, dataVolumePerTask, maxRangeSize); - - // if the parallelism is not legal, adjust to a legal parallelism - if (!isLegalParallelism(subpartitionRanges.size(), minParallelism, maxParallelism)) { - Optional> adjustedSubpartitionRanges = - adjustToClosestLegalParallelism( - dataVolumePerTask, - subpartitionRanges.size(), - minParallelism, - maxParallelism, - Arrays.stream(bytesBySubpartition).min().getAsLong(), - Arrays.stream(bytesBySubpartition).sum(), - limit -> computeParallelism(bytesBySubpartition, limit, maxRangeSize), - limit -> - computeSubpartitionRanges( - bytesBySubpartition, limit, maxRangeSize)); - if (!adjustedSubpartitionRanges.isPresent()) { - // can't find any legal parallelism, fall back to evenly distribute subpartitions - LOG.info( - "Cannot find a legal parallelism to evenly distribute data for job vertex {}. 
" - + "Fall back to compute a parallelism that can evenly distribute subpartitions.", - jobVertexId); - return decideParallelismAndEvenlyDistributeSubpartitions( - jobVertexId, - consumedResults, - initialParallelism, - minParallelism, - maxParallelism); - } - subpartitionRanges = adjustedSubpartitionRanges.get(); - } - - checkState(isLegalParallelism(subpartitionRanges.size(), minParallelism, maxParallelism)); - return createParallelismAndInputInfos(consumedResults, subpartitionRanges); - } - - private static boolean isLegalParallelism( - int parallelism, int minParallelism, int maxParallelism) { - return parallelism >= minParallelism && parallelism <= maxParallelism; - } - - private static int checkAndGetSubpartitionNum(List consumedResults) { - final Set subpartitionNumSet = - consumedResults.stream() - .flatMap( - resultInfo -> - IntStream.range(0, resultInfo.getNumPartitions()) - .boxed() - .map(resultInfo::getNumSubpartitions)) - .collect(Collectors.toSet()); - // all partitions have the same subpartition num - checkState(subpartitionNumSet.size() == 1); - return subpartitionNumSet.iterator().next(); - } - - /** - * Adjust the parallelism to the closest legal parallelism and return the computed subpartition - * ranges. - * - * @param currentDataVolumeLimit current data volume limit - * @param currentParallelism current parallelism - * @param minParallelism the min parallelism - * @param maxParallelism the max parallelism - * @param minLimit the minimum data volume limit - * @param maxLimit the maximum data volume limit - * @param parallelismComputer a function to compute the parallelism according to the data volume - * limit - * @param subpartitionRangesComputer a function to compute the subpartition ranges according to - * the data volume limit - * @return the computed subpartition ranges or {@link Optional#empty()} if we can't find any - * legal parallelism - */ - private static Optional> adjustToClosestLegalParallelism( - long currentDataVolumeLimit, - int currentParallelism, - int minParallelism, - int maxParallelism, - long minLimit, - long maxLimit, - Function parallelismComputer, - Function> subpartitionRangesComputer) { - long adjustedDataVolumeLimit = currentDataVolumeLimit; - if (currentParallelism < minParallelism) { - // Current parallelism is smaller than the user-specified lower-limit of parallelism , - // we need to adjust it to the closest/minimum possible legal parallelism. That is, we - // need to find the maximum legal dataVolumeLimit. - adjustedDataVolumeLimit = - BisectionSearchUtils.findMaxLegalValue( - value -> parallelismComputer.apply(value) >= minParallelism, - minLimit, - currentDataVolumeLimit); - - // When we find the minimum possible legal parallelism, the dataVolumeLimit that can - // lead to this parallelism may be a range, and we need to find the minimum value of - // this range to make the data distribution as even as possible (the smaller the - // dataVolumeLimit, the more even the distribution) - final long minPossibleLegalParallelism = - parallelismComputer.apply(adjustedDataVolumeLimit); - adjustedDataVolumeLimit = - BisectionSearchUtils.findMinLegalValue( - value -> - parallelismComputer.apply(value) == minPossibleLegalParallelism, - minLimit, - adjustedDataVolumeLimit); - - } else if (currentParallelism > maxParallelism) { - // Current parallelism is larger than the user-specified upper-limit of parallelism , - // we need to adjust it to the closest/maximum possible legal parallelism. 
That is, we - // need to find the minimum legal dataVolumeLimit. - adjustedDataVolumeLimit = - BisectionSearchUtils.findMinLegalValue( - value -> parallelismComputer.apply(value) <= maxParallelism, - currentDataVolumeLimit, - maxLimit); - } - - int adjustedParallelism = parallelismComputer.apply(adjustedDataVolumeLimit); - if (isLegalParallelism(adjustedParallelism, minParallelism, maxParallelism)) { - return Optional.of(subpartitionRangesComputer.apply(adjustedDataVolumeLimit)); - } else { - return Optional.empty(); - } - } - - private static ParallelismAndInputInfos createParallelismAndInputInfos( - List consumedResults, List subpartitionRanges) { - - final Map vertexInputInfos = new HashMap<>(); - consumedResults.forEach( - resultInfo -> { - int sourceParallelism = resultInfo.getNumPartitions(); - IndexRange partitionRange = new IndexRange(0, sourceParallelism - 1); - - List executionVertexInputInfos = new ArrayList<>(); - for (int i = 0; i < subpartitionRanges.size(); ++i) { - IndexRange subpartitionRange; - if (resultInfo.isBroadcast()) { - if (resultInfo.isSingleSubpartitionContainsAllData()) { - subpartitionRange = new IndexRange(0, 0); - } else { - // The partitions of the all-to-all result have the same number of - // subpartitions. So we can use the first partition's subpartition - // number. - subpartitionRange = - new IndexRange(0, resultInfo.getNumSubpartitions(0) - 1); - } - } else { - subpartitionRange = subpartitionRanges.get(i); - } - ExecutionVertexInputInfo executionVertexInputInfo = - new ExecutionVertexInputInfo(i, partitionRange, subpartitionRange); - executionVertexInputInfos.add(executionVertexInputInfo); - } - - vertexInputInfos.put( - resultInfo.getResultId(), - new JobVertexInputInfo(executionVertexInputInfos)); - }); - return new ParallelismAndInputInfos(subpartitionRanges.size(), vertexInputInfos); - } - - private static List computeSubpartitionRanges( - long[] nums, long limit, int maxRangeSize) { - List subpartitionRanges = new ArrayList<>(); - long tmpSum = 0; - int startIndex = 0; - for (int i = 0; i < nums.length; ++i) { - long num = nums[i]; - if (i == startIndex - || (tmpSum + num <= limit && (i - startIndex + 1) <= maxRangeSize)) { - tmpSum += num; - } else { - subpartitionRanges.add(new IndexRange(startIndex, i - 1)); - startIndex = i; - tmpSum = num; - } - } - subpartitionRanges.add(new IndexRange(startIndex, nums.length - 1)); - return subpartitionRanges; - } - - private static int computeParallelism(long[] nums, long limit, int maxRangeSize) { - long tmpSum = 0; - int startIndex = 0; - int count = 1; - for (int i = 0; i < nums.length; ++i) { - long num = nums[i]; - if (i == startIndex - || (tmpSum + num <= limit && (i - startIndex + 1) <= maxRangeSize)) { - tmpSum += num; - } else { - startIndex = i; - tmpSum = num; - count += 1; - } - } - return count; - } - - private static int getMaxNumPartitions(List consumedResults) { - checkArgument(!consumedResults.isEmpty()); - return consumedResults.stream() - .mapToInt(BlockingInputInfo::getNumPartitions) - .max() - .getAsInt(); - } - - private static int getMaxNumSubpartitions(List consumedResults) { - checkArgument(!consumedResults.isEmpty()); - return consumedResults.stream() - .mapToInt( - resultInfo -> - IntStream.range(0, resultInfo.getNumPartitions()) - .boxed() - .mapToInt(resultInfo::getNumSubpartitions) - .sum()) - .max() - .getAsInt(); - } - - private static List getNonBroadcastResultInfos( - List consumedResults) { - return consumedResults.stream() - .filter(resultInfo -> 
!resultInfo.isSingleSubpartitionContainsAllData()) - .collect(Collectors.toList()); - } - static DefaultVertexParallelismAndInputInfosDecider from( - int maxParallelism, Configuration configuration) { + int maxParallelism, + double skewedFactor, + long skewedThreshold, + Configuration configuration) { return new DefaultVertexParallelismAndInputInfosDecider( maxParallelism, configuration.get(BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_MIN_PARALLELISM), @@ -567,6 +310,8 @@ static DefaultVertexParallelismAndInputInfosDecider from( BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_AVG_DATA_VOLUME_PER_TASK), configuration.get( BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_DEFAULT_SOURCE_PARALLELISM, - maxParallelism)); + maxParallelism), + skewedFactor, + skewedThreshold); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AggregatedBlockingInputInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AggregatedBlockingInputInfo.java new file mode 100644 index 0000000000000..c01008f02623f --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AggregatedBlockingInputInfo.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch.util; + +import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.checkAndGetIntraCorrelation; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.checkAndGetSubpartitionNum; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.computeSkewThreshold; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.computeTargetSize; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.getMaxNumPartitions; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.hasSameNumPartitions; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.median; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Helper class that aggregates input information with the same typeNumber so that they can be + * processed as a single unit. 
+ */ +public class AggregatedBlockingInputInfo { + private static final Logger LOG = LoggerFactory.getLogger(AggregatedBlockingInputInfo.class); + + /** The maximum number of partitions among all aggregated inputs. */ + private final int maxPartitionNum; + + /** The threshold used to determine if a specific aggregated subpartition is skewed. */ + private final long skewedThreshold; + + /** The target size for splitting skewed aggregated subpartitions. */ + private final long targetSize; + + /** + * Indicates whether the records corresponding to the same key must be sent to the same + * downstream subtask. + */ + private final boolean intraInputKeyCorrelated; + + /** + * A map where the key is the partition index and the value is an array representing the size of + * each subpartition for that partition. This map is used to provide fine-grained information + * for splitting subpartitions with same index. If it is empty, means that the split operation + * cannot be performed. In the following cases, this map will be empty: 1. + * IntraInputKeyCorrelated is true. 2. The aggregated input infos have different num partitions. + * 3. The SubpartitionBytesByPartitionIndex of inputs is empty. + */ + private final Map subpartitionBytesByPartition; + + /** + * An array representing the aggregated size of each subpartition across all partitions. Each + * element in the array corresponds to a subpartition. + */ + private final long[] aggregatedSubpartitionBytes; + + private AggregatedBlockingInputInfo( + long targetSize, + long skewedThreshold, + int maxPartitionNum, + boolean intraInputKeyCorrelated, + Map subpartitionBytesByPartition, + long[] aggregatedSubpartitionBytes) { + this.maxPartitionNum = maxPartitionNum; + this.skewedThreshold = skewedThreshold; + this.targetSize = targetSize; + this.intraInputKeyCorrelated = intraInputKeyCorrelated; + this.subpartitionBytesByPartition = checkNotNull(subpartitionBytesByPartition); + this.aggregatedSubpartitionBytes = checkNotNull(aggregatedSubpartitionBytes); + } + + public int getMaxPartitionNum() { + return maxPartitionNum; + } + + public long getTargetSize() { + return targetSize; + } + + public Map getSubpartitionBytesByPartition() { + return Collections.unmodifiableMap(subpartitionBytesByPartition); + } + + public long getAggregatedSubpartitionBytes(int subpartitionIndex) { + return aggregatedSubpartitionBytes[subpartitionIndex]; + } + + public boolean isSplittable() { + return !intraInputKeyCorrelated && !subpartitionBytesByPartition.isEmpty(); + } + + public boolean isSkewedSubpartition(int subpartitionIndex) { + return aggregatedSubpartitionBytes[subpartitionIndex] > skewedThreshold; + } + + public int getNumSubpartitions() { + return aggregatedSubpartitionBytes.length; + } + + private static long[] computeAggregatedSubpartitionBytes( + List inputInfos, int subpartitionNum) { + long[] aggregatedSubpartitionBytes = new long[subpartitionNum]; + for (BlockingInputInfo inputInfo : inputInfos) { + List subpartitionBytes = inputInfo.getAggregatedSubpartitionBytes(); + for (int i = 0; i < subpartitionBytes.size(); i++) { + aggregatedSubpartitionBytes[i] += subpartitionBytes.get(i); + } + } + return aggregatedSubpartitionBytes; + } + + private static Map computeSubpartitionBytesByPartitionIndex( + List inputInfos, int subpartitionNum) { + // If inputInfos have different num partitions (means that these upstream have different + // parallelisms), return an empty result to disable data splitting. 
+ if (!hasSameNumPartitions(inputInfos)) { + LOG.warn( + "Input infos have different num partitions, skip calculate SubpartitionBytesByPartitionIndex"); + return Collections.emptyMap(); + } + Map subpartitionBytesByPartitionIndex = new HashMap<>(); + for (BlockingInputInfo inputInfo : inputInfos) { + inputInfo + .getSubpartitionBytesByPartitionIndex() + .forEach( + (partitionIdx, subPartitionBytes) -> { + long[] subpartitionBytes = + subpartitionBytesByPartitionIndex.computeIfAbsent( + partitionIdx, v -> new long[subpartitionNum]); + for (int i = 0; i < subpartitionNum; i++) { + subpartitionBytes[i] += subPartitionBytes[i]; + } + }); + } + return subpartitionBytesByPartitionIndex; + } + + public static AggregatedBlockingInputInfo createAggregatedBlockingInputInfo( + long defaultSkewedThreshold, + double skewedFactor, + long dataVolumePerTask, + List inputInfos) { + int subpartitionNum = checkAndGetSubpartitionNum(inputInfos); + long[] aggregatedSubpartitionBytes = + computeAggregatedSubpartitionBytes(inputInfos, subpartitionNum); + long skewedThreshold = + computeSkewThreshold( + median(aggregatedSubpartitionBytes), skewedFactor, defaultSkewedThreshold); + long targetSize = + computeTargetSize(aggregatedSubpartitionBytes, skewedThreshold, dataVolumePerTask); + boolean isIntraInputKeyCorrelated = checkAndGetIntraCorrelation(inputInfos); + Map subpartitionBytesByPartitionIndex; + if (isIntraInputKeyCorrelated) { + // subpartitions with same index will not be split, skipped calculate it + subpartitionBytesByPartitionIndex = new HashMap<>(); + } else { + subpartitionBytesByPartitionIndex = + computeSubpartitionBytesByPartitionIndex(inputInfos, subpartitionNum); + } + return new AggregatedBlockingInputInfo( + targetSize, + skewedThreshold, + getMaxNumPartitions(inputInfos), + isIntraInputKeyCorrelated, + subpartitionBytesByPartitionIndex, + aggregatedSubpartitionBytes); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AllToAllVertexInputInfoComputer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AllToAllVertexInputInfoComputer.java new file mode 100644 index 0000000000000..ffae30b677c22 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AllToAllVertexInputInfoComputer.java @@ -0,0 +1,486 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.adaptivebatch.util; + +import org.apache.flink.runtime.executiongraph.IndexRange; +import org.apache.flink.runtime.executiongraph.JobVertexInputInfo; +import org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.SubpartitionSlice.createSubpartitionSlice; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.SubpartitionSlice.createSubpartitionSlicesByMultiPartitionRanges; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.calculateDataVolumePerTaskForInput; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.calculateDataVolumePerTaskForInputsGroup; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.cartesianProduct; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.checkAndGetParallelism; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.checkAndGetSubpartitionNum; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.checkAndGetSubpartitionNumForAggregatedInputs; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.createdJobVertexInputInfoForBroadcast; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.createdJobVertexInputInfoForNonBroadcast; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.getNonBroadcastInputInfos; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.isLegalParallelism; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.tryComputeSubpartitionSliceRange; +import static org.apache.flink.util.Preconditions.checkState; + +/** Helper class that computes VertexInputInfo for all to all like inputs. */ +public class AllToAllVertexInputInfoComputer { + private static final Logger LOG = + LoggerFactory.getLogger(AllToAllVertexInputInfoComputer.class); + + private final double skewedFactor; + private final long defaultSkewedThreshold; + + public AllToAllVertexInputInfoComputer(double skewedFactor, long defaultSkewedThreshold) { + this.skewedFactor = skewedFactor; + this.defaultSkewedThreshold = defaultSkewedThreshold; + } + + /** + * Decide parallelism and input infos, which will make the data be evenly distributed to + * downstream subtasks for ALL_TO_ALL, such that different downstream subtasks consume roughly + * the same amount of data. + * + *

Assume there are two input infos upstream, each with three partitions and two + * subpartitions, whose data sizes in bytes are: input1: 0->[1,1] 1->[2,2] 2->[3,3], input2: + * 0->[1,1] 1->[1,1] 2->[1,1]. This method processes the data as follows:<br>
+ * 1. Create subpartition slices for inputs with the same type number. Unlike the pointwise + * computer, this method creates subpartition slices by following these steps: Firstly, + * reorganize the data by subpartition index: input1: {0->[1,2,3],1->[1,2,3]}, input2: + * {0->[1,1,1],1->[1,1,1]}. Secondly, split subpartitions with the same index into relatively + * balanced n parts (if possible): {0->[1,2][3],1->[1,2][3]}, {0->[1,1,1],1->[1,1,1]}. Then + * perform a cartesian product operation to ensure data correctness: input1: + * {0->[1,2],0->[3],1->[1,2],1->[3]}, input2: {0->[1,1,1],0->[1,1,1],1->[1,1,1],1->[1,1,1]}. + * Finally, create subpartition slices based on the result of the previous step, i.e., each input + * has four balanced subpartition slices.<br>
+ * 2. Based on the above subpartition slices, calculate the subpartition slice range each task + * needs to subscribe to, considering data volume and parallelism constraints: + * [0,0],[1,1],[2,2],[3,3]<br>
+ * 3. Convert the calculated subpartition slice range to the form of partition index range -> + * subpartition index range:<br>
+ * task0: input1: {[0,1]->[0]} input2:{[0,2]->[0]}<br>
+ * task1: input1: {[2,2]->[0]} input2:{[0,2]->[0]}<br>
+ * task2: input1: {[0,1]->[1]} input2:{[0,2]->[1]}<br>
+ * task3: input1: {[2,2]->[1]} input2:{[0,2]->[1]} + * + * @param jobVertexId The job vertex id + * @param inputInfos The information of consumed blocking results + * @param parallelism The parallelism of the job vertex + * @param minParallelism the min parallelism + * @param maxParallelism the max parallelism + * @param dataVolumePerTask proposed data volume per task for this set of inputInfo + * @return the parallelism and vertex input infos + */ + public Map compute( + JobVertexID jobVertexId, + List inputInfos, + int parallelism, + int minParallelism, + int maxParallelism, + long dataVolumePerTask) { + // For inputs with inter-keys correlation should be process together, as there is a + // correlation between them. For inputs without inter-keys, we should handle them + // separately. + List inputInfosWithoutInterKeysCorrelation = new ArrayList<>(); + List inputInfosWithInterKeysCorrelation = new ArrayList<>(); + for (BlockingInputInfo inputInfo : inputInfos) { + if (inputInfo.areInterInputsKeysCorrelated()) { + inputInfosWithInterKeysCorrelation.add(inputInfo); + } else { + inputInfosWithoutInterKeysCorrelation.add(inputInfo); + } + } + + Map vertexInputInfos = new HashMap<>(); + if (!inputInfosWithInterKeysCorrelation.isEmpty()) { + vertexInputInfos.putAll( + computeJobVertexInputInfosForInputsWithInterKeysCorrelation( + jobVertexId, + inputInfosWithInterKeysCorrelation, + parallelism, + minParallelism, + maxParallelism, + calculateDataVolumePerTaskForInputsGroup( + dataVolumePerTask, + inputInfosWithInterKeysCorrelation, + inputInfos))); + // Ensure the parallelism of inputs without inter and intra correlations is + // consistent with decided parallelism. + parallelism = checkAndGetParallelism(vertexInputInfos.values()); + } + + if (!inputInfosWithoutInterKeysCorrelation.isEmpty()) { + vertexInputInfos.putAll( + computeJobVertexInputInfosForInputsWithoutInterKeysCorrelation( + inputInfosWithoutInterKeysCorrelation, + parallelism, + calculateDataVolumePerTaskForInputsGroup( + dataVolumePerTask, + inputInfosWithoutInterKeysCorrelation, + inputInfos))); + } + + return vertexInputInfos; + } + + private Map + computeJobVertexInputInfosForInputsWithInterKeysCorrelation( + JobVertexID jobVertexId, + List inputInfos, + int parallelism, + int minParallelism, + int maxParallelism, + long dataVolumePerTask) { + List nonBroadcastInputInfos = getNonBroadcastInputInfos(inputInfos); + if (nonBroadcastInputInfos.isEmpty()) { + LOG.info( + "All inputs are broadcast for vertex {}, fallback to compute a parallelism that can evenly distribute num subpartitions.", + jobVertexId); + // This computer is only used in the adaptive batch scenario, where isDynamicGraph + // should always be true. + return VertexInputInfoComputationUtils.computeVertexInputInfos( + parallelism, inputInfos, true); + } + + // Divide the data into balanced n parts and describe each part by SubpartitionSlice. + Map> subpartitionSlicesByTypeNumber = + createSubpartitionSlicesForInputsWithInterKeysCorrelation( + nonBroadcastInputInfos, dataVolumePerTask); + + // Distribute the input data evenly among the downstream tasks and record the + // subpartition slice range for each task. 
+ Optional> optionalSubpartitionSliceRanges = + tryComputeSubpartitionSliceRange( + minParallelism, + maxParallelism, + dataVolumePerTask, + subpartitionSlicesByTypeNumber); + + if (optionalSubpartitionSliceRanges.isEmpty()) { + LOG.info( + "Cannot find a legal parallelism to evenly distribute data amount for job vertex {}, " + + "fallback to compute a parallelism that can evenly distribute num subpartitions.", + jobVertexId); + // This computer is only used in the adaptive batch scenario, where isDynamicGraph + // should always be true. + return VertexInputInfoComputationUtils.computeVertexInputInfos( + parallelism, inputInfos, true); + } + + List subpartitionSliceRanges = optionalSubpartitionSliceRanges.get(); + + checkState( + isLegalParallelism(subpartitionSliceRanges.size(), minParallelism, maxParallelism)); + + // Create vertex input info based on the subpartition slice and its range. + return createJobVertexInputInfos( + inputInfos, subpartitionSlicesByTypeNumber, subpartitionSliceRanges); + } + + private Map> + createSubpartitionSlicesForInputsWithInterKeysCorrelation( + List nonBroadcastInputInfos, long dataVolumePerTask) { + // Aggregate input info with the same type number. + Map aggregatedInputInfoByTypeNumber = + createAggregatedBlockingInputInfos(nonBroadcastInputInfos, dataVolumePerTask); + int subPartitionNum = + checkAndGetSubpartitionNumForAggregatedInputs( + aggregatedInputInfoByTypeNumber.values()); + Map> subpartitionSliceGroupByTypeNumber = new HashMap<>(); + for (int subpartitionIndex = 0; subpartitionIndex < subPartitionNum; ++subpartitionIndex) { + // Split the given subpartition group into balanced subpartition slices. + Map> subpartitionSlices = + createBalancedSubpartitionSlicesForInputsWithInterKeysCorrelation( + subpartitionIndex, aggregatedInputInfoByTypeNumber); + + List typeNumberList = new ArrayList<>(subpartitionSlices.keySet()); + + List> originalRangeLists = + new ArrayList<>(subpartitionSlices.values()); + + // Perform the Cartesian product for inputs with inter-inputs key correlation. + List> cartesianProductRangeList = + cartesianProduct(originalRangeLists); + + for (List subpartitionSlice : cartesianProductRangeList) { + for (int j = 0; j < subpartitionSlice.size(); ++j) { + int typeNumber = typeNumberList.get(j); + subpartitionSliceGroupByTypeNumber + .computeIfAbsent(typeNumber, ignored -> new ArrayList<>()) + .add(subpartitionSlice.get(j)); + } + } + } + + return subpartitionSliceGroupByTypeNumber; + } + + private Map createAggregatedBlockingInputInfos( + List nonBroadcastInputInfos, long dataVolumePerTask) { + Map> inputsByTypeNumber = + nonBroadcastInputInfos.stream() + .collect(Collectors.groupingBy(BlockingInputInfo::getInputTypeNumber)); + // Inputs with the same type number should be data with the same type, as operators will + // process them in the same way. Currently, they can be considered to must have the same + // IntraInputKeyCorrelation + checkState(hasSameIntraInputKeyCorrelation(inputsByTypeNumber)); + + Map blockingInputInfoContexts = new HashMap<>(); + for (Map.Entry> entry : inputsByTypeNumber.entrySet()) { + Integer typeNumber = entry.getKey(); + List inputInfos = entry.getValue(); + blockingInputInfoContexts.put( + typeNumber, + AggregatedBlockingInputInfo.createAggregatedBlockingInputInfo( + defaultSkewedThreshold, skewedFactor, dataVolumePerTask, inputInfos)); + } + + return blockingInputInfoContexts; + } + + /** + * Creates balanced subpartition slices for inputs with inter-key correlations. + * + *

This method generates a mapping of subpartition indices to lists of subpartition slices, + * ensuring balanced distribution of input data. When a subpartition is splittable and has data + * skew, we will split it into n continuous and balanced parts (by split its partition range). + * If the input is not splittable, this step will be skipped, and subpartitions with the same + * index will be aggregated into a single SubpartitionSlice. + * + * @param subpartitionIndex the index of the subpartition being processed. + * @param aggregatedInputInfoByTypeNumber a map of aggregated blocking input info, keyed by + * input type number. + * @return a map where the key is the input type number and the value is a list of subpartition + * slices for the specified subpartition. + */ + private static Map> + createBalancedSubpartitionSlicesForInputsWithInterKeysCorrelation( + int subpartitionIndex, + Map aggregatedInputInfoByTypeNumber) { + Map> subpartitionSlices = new HashMap<>(); + IndexRange subpartitionRange = new IndexRange(subpartitionIndex, subpartitionIndex); + for (Map.Entry entry : + aggregatedInputInfoByTypeNumber.entrySet()) { + Integer typeNumber = entry.getKey(); + AggregatedBlockingInputInfo aggregatedBlockingInputInfo = entry.getValue(); + if (aggregatedBlockingInputInfo.isSplittable() + && aggregatedBlockingInputInfo.isSkewedSubpartition(subpartitionIndex)) { + List partitionRanges = + computePartitionRangesEvenlyData( + subpartitionIndex, + aggregatedBlockingInputInfo.getTargetSize(), + aggregatedBlockingInputInfo.getSubpartitionBytesByPartition()); + subpartitionSlices.put( + typeNumber, + createSubpartitionSlicesByMultiPartitionRanges( + partitionRanges, + subpartitionRange, + aggregatedBlockingInputInfo.getSubpartitionBytesByPartition())); + } else { + IndexRange partitionRange = + new IndexRange(0, aggregatedBlockingInputInfo.getMaxPartitionNum() - 1); + subpartitionSlices.put( + typeNumber, + Collections.singletonList( + SubpartitionSlice.createSubpartitionSlice( + partitionRange, + subpartitionRange, + aggregatedBlockingInputInfo.getAggregatedSubpartitionBytes( + subpartitionIndex)))); + } + } + return subpartitionSlices; + } + + /** + * Splits a group of subpartitions with the same subpartition index into balanced slices based + * on the target size and returns the corresponding partition ranges. + * + * @param subPartitionIndex The index of the subpartition to be split. + * @param targetSize The target size for each slice. + * @param subPartitionBytesByPartitionIndex The byte size information of subpartitions in each + * partition, with the partition index as the key and the byte array as the value. + * @return A list of {@link IndexRange} objects representing the partition ranges of each slice. 
+ */ + private static List computePartitionRangesEvenlyData( + int subPartitionIndex, + long targetSize, + Map subPartitionBytesByPartitionIndex) { + List splitPartitionRange = new ArrayList<>(); + int partitionNum = subPartitionBytesByPartitionIndex.size(); + long tmpSum = 0; + int startIndex = 0; + for (int i = 0; i < partitionNum; ++i) { + long[] subPartitionBytes = subPartitionBytesByPartitionIndex.get(i); + long num = subPartitionBytes[subPartitionIndex]; + if (i == startIndex || tmpSum + num < targetSize) { + tmpSum += num; + } else { + splitPartitionRange.add(new IndexRange(startIndex, i - 1)); + startIndex = i; + tmpSum = num; + } + } + splitPartitionRange.add(new IndexRange(startIndex, partitionNum - 1)); + return splitPartitionRange; + } + + private static Map createJobVertexInputInfos( + List inputInfos, + Map> subpartitionSlices, + List subpartitionSliceRanges) { + final Map vertexInputInfos = new HashMap<>(); + for (BlockingInputInfo inputInfo : inputInfos) { + if (inputInfo.isBroadcast()) { + vertexInputInfos.put( + inputInfo.getResultId(), + createdJobVertexInputInfoForBroadcast( + inputInfo, subpartitionSliceRanges.size())); + } else { + vertexInputInfos.put( + inputInfo.getResultId(), + createdJobVertexInputInfoForNonBroadcast( + inputInfo, + subpartitionSliceRanges, + subpartitionSlices.get(inputInfo.getInputTypeNumber()))); + } + } + return vertexInputInfos; + } + + private Map + computeJobVertexInputInfosForInputsWithoutInterKeysCorrelation( + List inputInfos, int parallelism, long dataVolumePerTask) { + long totalDataBytes = + inputInfos.stream().mapToLong(BlockingInputInfo::getNumBytesProduced).sum(); + Map vertexInputInfos = new HashMap<>(); + // For inputs without inter-keys, we should process them one-by-one. + for (BlockingInputInfo inputInfo : inputInfos) { + vertexInputInfos.put( + inputInfo.getResultId(), + computeVertexInputInfoForInputWithoutInterKeysCorrelation( + inputInfo, + parallelism, + calculateDataVolumePerTaskForInput( + dataVolumePerTask, + inputInfo.getNumBytesProduced(), + totalDataBytes))); + } + return vertexInputInfos; + } + + private JobVertexInputInfo computeVertexInputInfoForInputWithoutInterKeysCorrelation( + BlockingInputInfo inputInfo, int parallelism, long dataVolumePerTask) { + if (inputInfo.isBroadcast()) { + return createdJobVertexInputInfoForBroadcast(inputInfo, parallelism); + } + + List subpartitionSlices = + createSubpartitionSlicesForInputWithoutInterKeysCorrelation(inputInfo); + // Node: SubpartitionSliceRanges does not represent the real index of the subpartitions, but + // the location of that subpartition in all subpartitions, as we aggregate all subpartitions + // into a one-digit array to calculate. 
+ Optional> optionalSubpartitionSliceRanges = + tryComputeSubpartitionSliceRange( + parallelism, + parallelism, + dataVolumePerTask, + Map.of(inputInfo.getInputTypeNumber(), subpartitionSlices)); + + if (optionalSubpartitionSliceRanges.isEmpty()) { + LOG.info( + "Cannot find a legal parallelism to evenly distribute data amount for input {}, " + + "fallback to compute a parallelism that can evenly distribute num subpartitions.", + inputInfo.getResultId()); + return VertexInputInfoComputationUtils.computeVertexInputInfoForPointwise( + inputInfo.getNumPartitions(), + parallelism, + inputInfo::getNumSubpartitions, + true); + } + + List subpartitionSliceRanges = optionalSubpartitionSliceRanges.get(); + + checkState(isLegalParallelism(subpartitionSliceRanges.size(), parallelism, parallelism)); + + // Create vertex input info based on the subpartition slice and ranges. + return createdJobVertexInputInfoForNonBroadcast( + inputInfo, subpartitionSliceRanges, subpartitionSlices); + } + + private List createSubpartitionSlicesForInputWithoutInterKeysCorrelation( + BlockingInputInfo inputInfo) { + List subpartitionSlices = new ArrayList<>(); + if (inputInfo.isIntraInputKeyCorrelated()) { + // If the input has intra-input correlation, we need to ensure all subpartitions + // with same index are assigned to the same downstream concurrent task. + + // The number of subpartitions of all partitions for all to all blocking result info + // should be consistent. + int numSubpartitions = checkAndGetSubpartitionNum(List.of(inputInfo)); + IndexRange partitionRange = new IndexRange(0, inputInfo.getNumPartitions() - 1); + for (int i = 0; i < numSubpartitions; ++i) { + IndexRange subpartitionRange = new IndexRange(i, i); + subpartitionSlices.add( + createSubpartitionSlice( + partitionRange, + subpartitionRange, + inputInfo.getNumBytesProduced(partitionRange, subpartitionRange))); + } + } else { + for (int i = 0; i < inputInfo.getNumPartitions(); ++i) { + IndexRange partitionRange = new IndexRange(i, i); + for (int j = 0; j < inputInfo.getNumSubpartitions(i); ++j) { + IndexRange subpartitionRange = new IndexRange(j, j); + subpartitionSlices.add( + createSubpartitionSlice( + partitionRange, + subpartitionRange, + inputInfo.getNumBytesProduced( + partitionRange, subpartitionRange))); + } + } + } + return subpartitionSlices; + } + + private static boolean hasSameIntraInputKeyCorrelation( + Map> inputGroups) { + return inputGroups.values().stream() + .allMatch( + inputs -> + inputs.stream() + .map(BlockingInputInfo::isIntraInputKeyCorrelated) + .distinct() + .count() + == 1); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/PointwiseVertexInputInfoComputer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/PointwiseVertexInputInfoComputer.java new file mode 100644 index 0000000000000..4397bea6a26b3 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/PointwiseVertexInputInfoComputer.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch.util; + +import org.apache.flink.runtime.executiongraph.IndexRange; +import org.apache.flink.runtime.executiongraph.JobVertexInputInfo; +import org.apache.flink.runtime.executiongraph.VertexInputInfoComputationUtils; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.SubpartitionSlice.createSubpartitionSlice; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.calculateDataVolumePerTaskForInput; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.createdJobVertexInputInfoForNonBroadcast; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.isLegalParallelism; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.tryComputeSubpartitionSliceRange; +import static org.apache.flink.util.Preconditions.checkState; + +/** Helper class that computes VertexInputInfo for pointwise input. */ +public class PointwiseVertexInputInfoComputer { + private static final Logger LOG = + LoggerFactory.getLogger(PointwiseVertexInputInfoComputer.class); + + /** + * Computes the input information for a job vertex based on the provided blocking input + * information and parallelism. + * + * @param inputInfos List of blocking input information for the job vertex. + * @param parallelism Parallelism of the job vertex. + * @param dataVolumePerTask Proposed data volume per task for this set of inputInfo. + * @return A map of intermediate data set IDs to their corresponding job vertex input + * information. + */ + public Map compute( + List inputInfos, int parallelism, long dataVolumePerTask) { + long totalDataBytes = + inputInfos.stream().mapToLong(BlockingInputInfo::getNumBytesProduced).sum(); + Map vertexInputInfos = new HashMap<>(); + for (BlockingInputInfo inputInfo : inputInfos) { + // Currently, we consider all inputs in this method must don't have inter-inputs key + // correlation. If other possibilities are introduced in the future, please add new + // branches to this method. + checkState(!inputInfo.areInterInputsKeysCorrelated()); + if (inputInfo.isIntraInputKeyCorrelated()) { + // In this case, we won't split subpartitions within the same partition, so need + // to ensure NumPartitions >= parallelism. 
+ checkState(parallelism <= inputInfo.getNumPartitions()); + } + vertexInputInfos.put( + inputInfo.getResultId(), + computeVertexInputInfo( + inputInfo, + parallelism, + calculateDataVolumePerTaskForInput( + dataVolumePerTask, + inputInfo.getNumBytesProduced(), + totalDataBytes))); + } + return vertexInputInfos; + } + + /** + * Decide parallelism and input infos, which will make the data be evenly distributed to + * downstream subtasks for POINTWISE, such that different downstream subtasks consume roughly + * the same amount of data. + * + *

+     * <p>Assume that `inputInfo` has two partitions, each with three subpartitions, whose data
+     * bytes are {0->[1,2,1], 1->[2,1,2]}, and that the expected parallelism is 3. The calculation
+     * process is as follows:
+     * 1. Create subpartition slices for the input, which is composed of several subpartitions.
+     * The data bytes of the created slice list are: [1,2,1,2,1,2]
+ * 2. Distribute the subpartition slices array into n balanced parts (described by `IndexRange`, + * named SubpartitionSliceRanges) based on data volume: [0,1],[2,3],[4,5]
+ * 3. Reorganize the distributed results into a mapping of partition range to subpartition + * range: {0 -> [0,1]}, {0->[2,2],1->[0,0]}, {1->[1,2]}.
+     * <p>The final result is the `SubpartitionGroup` that each of the three parallel tasks needs
+     * to subscribe to.
+     *
+     * @param inputInfo The information of consumed blocking results
+     * @param parallelism The parallelism of the job vertex. Since pointwise inputs always compute
+     *     vertex input info one-by-one, we need a determined parallelism to ensure the final
+     *     decided parallelism for all inputs is consistent.
+     * @return the vertex input info
+     */
+    private static JobVertexInputInfo computeVertexInputInfo(
+            BlockingInputInfo inputInfo, int parallelism, long dataVolumePerTask) {
+        List<SubpartitionSlice> subpartitionSlices = createSubpartitionSlices(inputInfo);
+
+        // Note: SubpartitionSliceRanges does not represent the real index of the subpartitions,
+        // but the location of that subpartition in all subpartitions, as we aggregate all
+        // subpartitions into a one-dimensional array to calculate.
+        Optional<List<IndexRange>> optionalSubpartitionSliceRanges =
+                tryComputeSubpartitionSliceRange(
+                        parallelism,
+                        parallelism,
+                        dataVolumePerTask,
+                        Map.of(inputInfo.getInputTypeNumber(), subpartitionSlices));
+
+        if (optionalSubpartitionSliceRanges.isEmpty()) {
+            LOG.info(
+                    "Cannot find a legal parallelism to evenly distribute data amount for input {}, "
+                            + "fallback to compute a parallelism that can evenly distribute num subpartitions.",
+                    inputInfo.getResultId());
+            // This computer is only used in the adaptive batch scenario, where isDynamicGraph
+            // should always be true.
+            return VertexInputInfoComputationUtils.computeVertexInputInfoForPointwise(
+                    inputInfo.getNumPartitions(),
+                    parallelism,
+                    inputInfo::getNumSubpartitions,
+                    true);
+        }
+
+        List<IndexRange> subpartitionSliceRanges = optionalSubpartitionSliceRanges.get();
+
+        checkState(isLegalParallelism(subpartitionSliceRanges.size(), parallelism, parallelism));
+
+        // Create vertex input info based on the subpartition slice and ranges.
+        return createJobVertexInputInfo(inputInfo, subpartitionSliceRanges, subpartitionSlices);
+    }
+
+    private static List<SubpartitionSlice> createSubpartitionSlices(BlockingInputInfo inputInfo) {
+        List<SubpartitionSlice> subpartitionSlices = new ArrayList<>();
+        if (inputInfo.isIntraInputKeyCorrelated()) {
+            // If the input has intra-input correlation, we need to ensure all subpartitions
+            // in the same partition index are assigned to the same downstream concurrent task.
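+            // For example, a partition i with 3 subpartitions yields one slice covering
+            // partition range [i, i] and subpartition range [0, 2].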
+ for (int i = 0; i < inputInfo.getNumPartitions(); ++i) { + IndexRange partitionRange = new IndexRange(i, i); + IndexRange subpartitionRange = + new IndexRange(0, inputInfo.getNumSubpartitions(i) - 1); + subpartitionSlices.add( + createSubpartitionSlice( + partitionRange, + subpartitionRange, + inputInfo.getNumBytesProduced(partitionRange, subpartitionRange))); + } + } else { + for (int i = 0; i < inputInfo.getNumPartitions(); ++i) { + IndexRange partitionRange = new IndexRange(i, i); + for (int j = 0; j < inputInfo.getNumSubpartitions(i); ++j) { + IndexRange subpartitionRange = new IndexRange(j, j); + subpartitionSlices.add( + createSubpartitionSlice( + partitionRange, + subpartitionRange, + inputInfo.getNumBytesProduced( + partitionRange, subpartitionRange))); + } + } + } + return subpartitionSlices; + } + + private static JobVertexInputInfo createJobVertexInputInfo( + BlockingInputInfo inputInfo, + List subpartitionSliceRanges, + List subpartitionSlices) { + checkState(!inputInfo.isBroadcast()); + return createdJobVertexInputInfoForNonBroadcast( + inputInfo, subpartitionSliceRanges, subpartitionSlices); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/SubpartitionSlice.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/SubpartitionSlice.java new file mode 100644 index 0000000000000..9a4ec5bf0bd10 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/SubpartitionSlice.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch.util; + +import org.apache.flink.runtime.executiongraph.IndexRange; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Helper class that describes the statistics of all subpartitions with a specific index within the + * given partition range. It may represent a complete subpartition group or a part of the + * subpartition group, depending on the partition range. + */ +public class SubpartitionSlice { + + /** The range of partitions that the subpartition slice covers. */ + private final IndexRange partitionRange; + + /** The range of subpartitions that the subpartition slice covers. */ + private final IndexRange subpartitionRange; + + /** The size of the subpartition slice in bytes. 
*/ + private final long dataBytes; + + private SubpartitionSlice( + IndexRange partitionRange, IndexRange subpartitionRange, long dataBytes) { + this.partitionRange = checkNotNull(partitionRange); + this.subpartitionRange = checkNotNull(subpartitionRange); + this.dataBytes = dataBytes; + } + + public long getDataBytes() { + return dataBytes; + } + + public IndexRange getSubpartitionRange() { + return subpartitionRange; + } + + /** + * SubpartitionSlice is used to describe a group of inputs with the same type number which may + * have different numbers of partitions, so we need to use the specific partitions number to get + * the correct partition range. + * + *

+     * <p>For example, given a specific typeNumber with 2 inputs and partition counts of 3 and 2
+     * respectively, if the current SubpartitionSlice's PartitionRange is [1,2], it may need
+     * adjustment for the second input. The adjustment ensures that the PartitionRange aligns with
+     * the expected partition count.
+     * - input 0: partition count = 3, valid PartitionRange = [0, 2]
+     * - input 1: partition count = 2, valid PartitionRange = [0, 1]
+ * If the SubpartitionSlice's PartitionRange is [1, 2], it should be corrected to [1, 1] for + * typeNumber 1 to match its partition count. + * + * @param numPartitions the number of partitions + * @return the partition range if the partition range is valid, empty otherwise + */ + public IndexRange getPartitionRange(int numPartitions) { + if (partitionRange.getEndIndex() < numPartitions) { + return partitionRange; + } else if (partitionRange.getStartIndex() < numPartitions + && partitionRange.getEndIndex() >= numPartitions) { + return new IndexRange(partitionRange.getStartIndex(), numPartitions - 1); + } else { + throw new IllegalStateException( + "Invalid partition range " + + partitionRange + + ", number of partitions: " + + numPartitions + + "."); + } + } + + public static SubpartitionSlice createSubpartitionSlice( + IndexRange partitionRange, IndexRange subpartitionRange, long dataBytes) { + return new SubpartitionSlice(partitionRange, subpartitionRange, dataBytes); + } + + public static List createSubpartitionSlicesByMultiPartitionRanges( + List partitionRanges, + IndexRange subpartitionRange, + Map subpartitionBytesByPartition) { + List subpartitionSlices = new ArrayList<>(); + for (IndexRange partitionRange : partitionRanges) { + subpartitionSlices.add( + createSubpartitionSlice( + partitionRange, + subpartitionRange, + calculateDataBytes( + partitionRange, + subpartitionRange, + subpartitionBytesByPartition))); + } + return subpartitionSlices; + } + + private static long calculateDataBytes( + IndexRange partitionRange, + IndexRange subpartitionRange, + Map subpartitionBytesByPartitionIndex) { + return IntStream.rangeClosed(partitionRange.getStartIndex(), partitionRange.getEndIndex()) + .mapToLong( + partitionIndex -> + IntStream.rangeClosed( + subpartitionRange.getStartIndex(), + subpartitionRange.getEndIndex()) + .mapToLong( + subpartitionIndex -> + subpartitionBytesByPartitionIndex + .get(partitionIndex)[ + subpartitionIndex]) + .sum()) + .sum(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtils.java new file mode 100644 index 0000000000000..8e3487ad1178a --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtils.java @@ -0,0 +1,661 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.scheduler.adaptivebatch.util; + +import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo; +import org.apache.flink.runtime.executiongraph.IndexRange; +import org.apache.flink.runtime.executiongraph.JobVertexInputInfo; +import org.apache.flink.runtime.scheduler.adaptivebatch.BisectionSearchUtils; +import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.LongStream; + +import static org.apache.flink.runtime.executiongraph.IndexRangeUtil.mergeIndexRanges; +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkState; + +/** Utils class for VertexParallelismAndInputInfosDecider. */ +public class VertexParallelismAndInputInfosDeciderUtils { + private static final Logger LOG = + LoggerFactory.getLogger(VertexParallelismAndInputInfosDeciderUtils.class); + + /** + * Adjust the parallelism to the closest legal parallelism and return the computed subpartition + * ranges. + * + * @param currentDataVolumeLimit current data volume limit + * @param currentParallelism current parallelism + * @param minParallelism the min parallelism + * @param maxParallelism the max parallelism + * @param minLimit the minimum data volume limit + * @param maxLimit the maximum data volume limit + * @param parallelismComputer a function to compute the parallelism according to the data volume + * limit + * @param subpartitionRangesComputer a function to compute the subpartition ranges according to + * the data volume limit + * @return the computed subpartition ranges or {@link Optional#empty()} if we can't find any + * legal parallelism + */ + public static Optional> adjustToClosestLegalParallelism( + long currentDataVolumeLimit, + int currentParallelism, + int minParallelism, + int maxParallelism, + long minLimit, + long maxLimit, + Function parallelismComputer, + Function> subpartitionRangesComputer) { + long adjustedDataVolumeLimit = currentDataVolumeLimit; + if (currentParallelism < minParallelism) { + // Current parallelism is smaller than the user-specified lower-limit of parallelism , + // we need to adjust it to the closest/minimum possible legal parallelism. That is, we + // need to find the maximum legal dataVolumeLimit. 
+ adjustedDataVolumeLimit = + BisectionSearchUtils.findMaxLegalValue( + value -> parallelismComputer.apply(value) >= minParallelism, + minLimit, + currentDataVolumeLimit); + + // When we find the minimum possible legal parallelism, the dataVolumeLimit that can + // lead to this parallelism may be a range, and we need to find the minimum value of + // this range to make the data distribution as even as possible (the smaller the + // dataVolumeLimit, the more even the distribution) + final long minPossibleLegalParallelism = + parallelismComputer.apply(adjustedDataVolumeLimit); + adjustedDataVolumeLimit = + BisectionSearchUtils.findMinLegalValue( + value -> + parallelismComputer.apply(value) == minPossibleLegalParallelism, + minLimit, + adjustedDataVolumeLimit); + + } else if (currentParallelism > maxParallelism) { + // Current parallelism is larger than the user-specified upper-limit of parallelism , + // we need to adjust it to the closest/maximum possible legal parallelism. That is, we + // need to find the minimum legal dataVolumeLimit. + adjustedDataVolumeLimit = + BisectionSearchUtils.findMinLegalValue( + value -> parallelismComputer.apply(value) <= maxParallelism, + currentDataVolumeLimit, + maxLimit); + } + + int adjustedParallelism = parallelismComputer.apply(adjustedDataVolumeLimit); + if (isLegalParallelism(adjustedParallelism, minParallelism, maxParallelism)) { + return Optional.of(subpartitionRangesComputer.apply(adjustedDataVolumeLimit)); + } else { + return Optional.empty(); + } + } + + /** + * Computes the Cartesian product of a list of lists. + * + *

The Cartesian product is a set of all possible combinations formed by picking one element + * from each list. For example, given input lists [[1, 2], [3, 4]], the result will be [[1, 3], + * [1, 4], [2, 3], [2, 4]]. + * + *
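+     * <p>A minimal usage sketch of that example (the values here are illustrative only):
+     *
+     * <pre>{@code
+     * List<List<Integer>> product =
+     *         cartesianProduct(List.of(List.of(1, 2), List.of(3, 4)));
+     * // product: [[1, 3], [1, 4], [2, 3], [2, 4]]
+     * }</pre>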

+     * <p>Note: If any inner list is empty, the result will be an empty list; if the input list
+     * itself is empty, the result contains a single empty combination.
+     *
+     * @param <T> the type of elements in the lists
+     * @param lists a list of lists for which the Cartesian product is to be computed
+     * @return a list of lists representing the Cartesian product, where each inner list is a
+     *     combination
+     */
+    public static <T> List<List<T>> cartesianProduct(List<List<T>> lists) {
+        List<List<T>> resultLists = new ArrayList<>();
+        if (lists.isEmpty()) {
+            resultLists.add(new ArrayList<>());
+            return resultLists;
+        } else {
+            List<T> firstList = lists.get(0);
+            List<List<T>> remainingLists = cartesianProduct(lists.subList(1, lists.size()));
+            for (T condition : firstList) {
+                for (List<T> remainingList : remainingLists) {
+                    ArrayList<T> resultList = new ArrayList<>();
+                    resultList.add(condition);
+                    resultList.addAll(remainingList);
+                    resultLists.add(resultList);
+                }
+            }
+        }
+        return resultLists;
+    }
+
+    /**
+     * Calculates the median of a given array of long integers. If the calculated median is less
+     * than 1, it returns 1 instead.
+     *
+     * @param nums an array of long integers for which to calculate the median.
+     * @return the median value, which will be at least 1.
+     */
+    public static long median(long[] nums) {
+        int len = nums.length;
+        long[] sortedNums = LongStream.of(nums).sorted().toArray();
+        if (len % 2 == 0) {
+            return Math.max((sortedNums[len / 2] + sortedNums[len / 2 - 1]) / 2, 1L);
+        } else {
+            return Math.max(sortedNums[len / 2], 1L);
+        }
+    }
+
+    /**
+     * Computes the skew threshold based on the given median size and skewed factor.
+     *
+     *
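+     * <p>For example (illustrative values): with a median size of 100, a skewed factor of 4.0 and
+     * a default threshold of 256, the computed threshold is max(100 * 4.0, 256) = 400.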

+     * <p>The skew threshold is calculated as the product of the median size and the skewed
+     * factor. To ensure that the computed threshold does not fall below a specified default
+     * value, the method uses {@link Math#max} to return the largest of the calculated threshold
+     * and the default threshold.
+     *
+     * @param medianSize the median size of the subpartitions
+     * @param skewedFactor a factor indicating the degree of skewness
+     * @param defaultSkewedThreshold the default threshold to be used if the calculated threshold
+     *     is less than this value
+     * @return the computed skew threshold, which is guaranteed to be at least the default skewed
+     *     threshold.
+     */
+    public static long computeSkewThreshold(
+            long medianSize, double skewedFactor, long defaultSkewedThreshold) {
+        return (long) Math.max(medianSize * skewedFactor, defaultSkewedThreshold);
+    }
+
+    /**
+     * Computes the target data size for each task based on the sizes of non-skewed subpartitions.
+     *
+     *
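+     * <p>For example (illustrative values): given subpartition sizes [100, 100, 1000], a skewed
+     * threshold of 500 and a data volume per task of 150, the non-skewed sizes are [100, 100]
+     * with an average of 100, so the target size is max(150, 100) = 150.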

The target size is determined as the average size of non-skewed subpartitions and ensures + * that the target size is at least equal to the specified data volume per task. + * + * @param subpartitionBytes an array representing the data size of each subpartition + * @param skewedThreshold skewed threshold in bytes + * @param dataVolumePerTask the amount of data that should be allocated per task + * @return the computed target size for each task, which is the maximum between the average size + * of non-skewed subpartitions and data volume per task. + */ + public static long computeTargetSize( + long[] subpartitionBytes, long skewedThreshold, long dataVolumePerTask) { + long[] nonSkewPartitions = + LongStream.of(subpartitionBytes).filter(v -> v <= skewedThreshold).toArray(); + if (nonSkewPartitions.length == 0) { + return dataVolumePerTask; + } else { + return Math.max( + dataVolumePerTask, + LongStream.of(nonSkewPartitions).sum() / nonSkewPartitions.length); + } + } + + public static List getNonBroadcastInputInfos( + List consumedResults) { + return consumedResults.stream() + .filter(resultInfo -> !resultInfo.isBroadcast()) + .collect(Collectors.toList()); + } + + public static boolean hasSameNumPartitions(List inputInfos) { + Set partitionNums = + inputInfos.stream() + .map(BlockingInputInfo::getNumPartitions) + .collect(Collectors.toSet()); + return partitionNums.size() == 1; + } + + public static int getMaxNumPartitions(List consumedResults) { + checkArgument(!consumedResults.isEmpty()); + return consumedResults.stream() + .mapToInt(BlockingInputInfo::getNumPartitions) + .max() + .getAsInt(); + } + + public static int checkAndGetSubpartitionNum(List consumedResults) { + final Set subpartitionNumSet = + consumedResults.stream() + .flatMap( + resultInfo -> + IntStream.range(0, resultInfo.getNumPartitions()) + .boxed() + .map(resultInfo::getNumSubpartitions)) + .collect(Collectors.toSet()); + // all partitions have the same subpartition num + checkState(subpartitionNumSet.size() == 1); + return subpartitionNumSet.iterator().next(); + } + + public static int checkAndGetSubpartitionNumForAggregatedInputs( + Collection inputInfos) { + final Set subpartitionNumSet = + inputInfos.stream() + .map(AggregatedBlockingInputInfo::getNumSubpartitions) + .collect(Collectors.toSet()); + // all partitions have the same subpartition num + checkState(subpartitionNumSet.size() == 1); + return subpartitionNumSet.iterator().next(); + } + + public static boolean isLegalParallelism( + int parallelism, int minParallelism, int maxParallelism) { + return parallelism >= minParallelism && parallelism <= maxParallelism; + } + + public static boolean checkAndGetIntraCorrelation(List inputInfos) { + Set intraCorrelationSet = + inputInfos.stream() + .map(BlockingInputInfo::isIntraInputKeyCorrelated) + .collect(Collectors.toSet()); + checkArgument(intraCorrelationSet.size() == 1); + return intraCorrelationSet.iterator().next(); + } + + public static int checkAndGetParallelism(Collection vertexInputInfos) { + final Set parallelismSet = + vertexInputInfos.stream() + .map( + vertexInputInfo -> + vertexInputInfo.getExecutionVertexInputInfos().size()) + .collect(Collectors.toSet()); + checkState(parallelismSet.size() == 1); + return parallelismSet.iterator().next(); + } + + /** + * Attempts to compute the subpartition slice ranges to ensure even distribution of data across + * downstream tasks. + * + *

This method first tries to compute the subpartition slice ranges by evenly distributing + * the data volume. If that fails, it attempts to compute the ranges by evenly distributing the + * number of subpartition slices. + * + * @param minParallelism The minimum parallelism. + * @param maxParallelism The maximum parallelism. + * @param maxDataVolumePerTask The maximum data volume per task. + * @param subpartitionSlicesByTypeNumber A map of lists of subpartition slices grouped by type + * number. + * @return An {@code Optional} containing a list of index ranges representing the subpartition + * slice ranges. Returns an empty {@code Optional} if no suitable ranges can be computed. + */ + public static Optional> tryComputeSubpartitionSliceRange( + int minParallelism, + int maxParallelism, + long maxDataVolumePerTask, + Map> subpartitionSlicesByTypeNumber) { + Optional> subpartitionSliceRanges = + tryComputeSubpartitionSliceRangeEvenlyDistributedData( + minParallelism, + maxParallelism, + maxDataVolumePerTask, + subpartitionSlicesByTypeNumber); + if (subpartitionSliceRanges.isEmpty()) { + LOG.info( + "Failed to compute a legal subpartition slice range that can evenly distribute data amount, " + + "fallback to compute it that can evenly distribute the number of subpartition slices."); + subpartitionSliceRanges = + tryComputeSubpartitionSliceRangeEvenlyDistributedSubpartitionSlices( + minParallelism, maxParallelism, subpartitionSlicesByTypeNumber); + } + return subpartitionSliceRanges; + } + + public static JobVertexInputInfo createdJobVertexInputInfoForBroadcast( + BlockingInputInfo inputInfo, int parallelism) { + checkArgument(inputInfo.isBroadcast()); + int numPartitions = inputInfo.getNumPartitions(); + List executionVertexInputInfos = new ArrayList<>(); + for (int i = 0; i < parallelism; ++i) { + ExecutionVertexInputInfo executionVertexInputInfo; + if (inputInfo.isSingleSubpartitionContainsAllData()) { + executionVertexInputInfo = + new ExecutionVertexInputInfo( + i, new IndexRange(0, numPartitions - 1), new IndexRange(0, 0)); + } else { + // The partitions of the all-to-all result have the same number of + // subpartitions. So we can use the first partition's subpartition + // number. 
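+                // For example, with 4 partitions and 3 subpartitions per partition, every
+                // subtask is assigned partition range [0, 3] and subpartition range [0, 2].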
+ executionVertexInputInfo = + new ExecutionVertexInputInfo( + i, + new IndexRange(0, numPartitions - 1), + new IndexRange(0, inputInfo.getNumSubpartitions(0) - 1)); + } + executionVertexInputInfos.add(executionVertexInputInfo); + } + return new JobVertexInputInfo(executionVertexInputInfos); + } + + public static JobVertexInputInfo createdJobVertexInputInfoForNonBroadcast( + BlockingInputInfo inputInfo, + List subpartitionSliceRanges, + List subpartitionSlices) { + checkArgument(!inputInfo.isBroadcast()); + int numPartitions = inputInfo.getNumPartitions(); + List executionVertexInputInfos = new ArrayList<>(); + for (int i = 0; i < subpartitionSliceRanges.size(); ++i) { + IndexRange subpartitionSliceRange = subpartitionSliceRanges.get(i); + // Convert subpartitionSlices to partition range to subpartition range + Map consumedSubpartitionGroups = + computeConsumedSubpartitionGroups( + subpartitionSliceRange, + subpartitionSlices, + numPartitions, + inputInfo.isPointwise()); + executionVertexInputInfos.add( + new ExecutionVertexInputInfo(i, consumedSubpartitionGroups)); + } + return new JobVertexInputInfo(executionVertexInputInfos); + } + + private static Optional> tryComputeSubpartitionSliceRangeEvenlyDistributedData( + int minParallelism, + int maxParallelism, + long maxDataVolumePerTask, + Map> subpartitionSlicesByTypeNumber) { + int subpartitionSlicesSize = + checkAndGetSubpartitionSlicesSize(subpartitionSlicesByTypeNumber); + // Distribute the input data evenly among the downstream tasks and record the + // subpartition slice range for each task. + List subpartitionSliceRanges = + computeSubpartitionSliceRanges( + maxDataVolumePerTask, + subpartitionSlicesSize, + subpartitionSlicesByTypeNumber); + // if the parallelism is not legal, try to adjust to a legal parallelism + if (!isLegalParallelism(subpartitionSliceRanges.size(), minParallelism, maxParallelism)) { + long minBytesSize = maxDataVolumePerTask; + long sumBytesSize = 0; + for (int i = 0; i < subpartitionSlicesSize; ++i) { + long currentBytesSize = 0; + for (List subpartitionSlice : + subpartitionSlicesByTypeNumber.values()) { + currentBytesSize += subpartitionSlice.get(i).getDataBytes(); + } + minBytesSize = Math.min(minBytesSize, currentBytesSize); + sumBytesSize += currentBytesSize; + } + return adjustToClosestLegalParallelism( + maxDataVolumePerTask, + subpartitionSliceRanges.size(), + minParallelism, + maxParallelism, + minBytesSize, + sumBytesSize, + limit -> + computeParallelism( + limit, subpartitionSlicesSize, subpartitionSlicesByTypeNumber), + limit -> + computeSubpartitionSliceRanges( + limit, subpartitionSlicesSize, subpartitionSlicesByTypeNumber)); + } + return Optional.of(subpartitionSliceRanges); + } + + private static Optional> + tryComputeSubpartitionSliceRangeEvenlyDistributedSubpartitionSlices( + int minParallelism, + int maxParallelism, + Map> subpartitionSlicesByTypeNumber) { + int subpartitionSlicesSize = + checkAndGetSubpartitionSlicesSize(subpartitionSlicesByTypeNumber); + if (subpartitionSlicesSize < minParallelism) { + return Optional.empty(); + } + int parallelism = Math.min(subpartitionSlicesSize, maxParallelism); + List subpartitionSliceRanges = new ArrayList<>(); + for (int i = 0; i < parallelism; i++) { + int start = i * subpartitionSlicesSize / parallelism; + int nextStart = (i + 1) * subpartitionSlicesSize / parallelism; + subpartitionSliceRanges.add(new IndexRange(start, nextStart - 1)); + } + checkState(subpartitionSliceRanges.size() == parallelism); + return Optional.of(subpartitionSliceRanges); 
+    }
+
+    /**
+     * Merges the subpartition slices of the specified range into an index range map, in which the
+     * key is the partition index range and the value is the subpartition range.
+     *
+     *

+     * <p>Note: In existing algorithms, the consumed subpartition groups for POINTWISE always
+     * ensure that there is no overlap in the partition ranges, while for ALL_TO_ALL, the consumed
+     * subpartition groups always ensure that there is no overlap in the subpartition ranges. For
+     * example, if a task needs to subscribe to {[0,0]->[0,1], [1,1]->[0,0]} (partition range to
+     * subpartition range), for POINTWISE it will be: {[0,0]->[0,1], [1,1]->[0,0]}, for ALL_TO_ALL
+     * it will be: {[0,1]->[0,0], [0,0]->[1,1]}. The result of this method will also follow this
+     * convention.
+     *
+     * @param subpartitionSliceRange the range of subpartition slices to be merged
+     * @param subpartitionSlices subpartition slices
+     * @param numPartitions the real number of partitions of input info, used to correct the
+     *     partition range
+     * @param isPointwise whether the input info is pointwise
+     * @return a map indicating the ranges that the task needs to consume; the key is the partition
+     *     range and the value is the subpartition range.
+     */
+    private static Map<IndexRange, IndexRange> computeConsumedSubpartitionGroups(
+            IndexRange subpartitionSliceRange,
+            List<SubpartitionSlice> subpartitionSlices,
+            int numPartitions,
+            boolean isPointwise) {
+        Map<IndexRange, List<IndexRange>> rangeMap =
+                new TreeMap<>(Comparator.comparingInt(IndexRange::getStartIndex));
+        for (int i = subpartitionSliceRange.getStartIndex();
+                i <= subpartitionSliceRange.getEndIndex();
+                ++i) {
+            SubpartitionSlice subpartitionSlice = subpartitionSlices.get(i);
+            IndexRange keyRange, valueRange;
+            if (isPointwise) {
+                keyRange = subpartitionSlice.getPartitionRange(numPartitions);
+                valueRange = subpartitionSlice.getSubpartitionRange();
+            } else {
+                keyRange = subpartitionSlice.getSubpartitionRange();
+                valueRange = subpartitionSlice.getPartitionRange(numPartitions);
+            }
+            rangeMap.computeIfAbsent(keyRange, k -> new ArrayList<>()).add(valueRange);
+        }
+
+        rangeMap =
+                rangeMap.entrySet().stream()
+                        .collect(
+                                Collectors.toMap(
+                                        Map.Entry::getKey,
+                                        entry -> mergeIndexRanges(entry.getValue())));
+
+        // reverse the map to merge keys associated with the same value
+        Map<IndexRange, List<IndexRange>> reversedRangeMap = new HashMap<>();
+        for (Map.Entry<IndexRange, List<IndexRange>> entry : rangeMap.entrySet()) {
+            IndexRange valueRange = entry.getKey();
+            for (IndexRange keyRange : entry.getValue()) {
+                reversedRangeMap.computeIfAbsent(keyRange, k -> new ArrayList<>()).add(valueRange);
+            }
+        }
+
+        Map<IndexRange, IndexRange> mergedReversedRangeMap =
+                reversedRangeMap.entrySet().stream()
+                        .collect(
+                                Collectors.toMap(
+                                        Map.Entry::getKey,
+                                        entry -> {
+                                            List<IndexRange> mergedRange =
+                                                    mergeIndexRanges(entry.getValue());
+                                            checkState(mergedRange.size() == 1);
+                                            return mergedRange.get(0);
+                                        }));
+
+        if (isPointwise) {
+            return reverseIndexRangeMap(mergedReversedRangeMap);
+        }
+
+        return mergedReversedRangeMap;
+    }
+
+    /**
+     * Reassembles subpartition slices into n balanced parts and returns the index range
+     * corresponding to each piece of data. The reassembling needs to meet the following conditions:
+ * 1. The data size of each piece does not exceed the limit.
+ * 2. The SubpartitionSlice number in each piece is not larger than maxRangeSize. + * + * @param limit the limit of data size + * @param subpartitionGroupSize the number of SubpartitionSlices + * @param subpartitionSlices the subpartition slices to be processed + * @return the range of index corresponding to each piece of data + */ + private static List computeSubpartitionSliceRanges( + long limit, + int subpartitionGroupSize, + Map> subpartitionSlices) { + List subpartitionSliceRanges = new ArrayList<>(); + long accumulatedSize = 0; + int startIndex = 0; + Map> bucketsByTypeNumber = new HashMap<>(); + for (int i = 0; i < subpartitionGroupSize; ++i) { + long currentGroupSize = 0L; + // bytes size after deduplication + long currentGroupSizeDeduplicated = 0L; + for (Map.Entry> entry : + subpartitionSlices.entrySet()) { + Integer typeNumber = entry.getKey(); + SubpartitionSlice subpartitionSlice = entry.getValue().get(i); + Set bucket = + bucketsByTypeNumber.computeIfAbsent(typeNumber, ignored -> new HashSet<>()); + // When the bucket already contains duplicate subpartitionSlices, its size should be + // ignored. + if (!bucket.contains(subpartitionSlice)) { + currentGroupSizeDeduplicated += subpartitionSlice.getDataBytes(); + } + currentGroupSize += subpartitionSlice.getDataBytes(); + } + if (i == startIndex || accumulatedSize + currentGroupSizeDeduplicated <= limit) { + accumulatedSize += currentGroupSizeDeduplicated; + } else { + subpartitionSliceRanges.add(new IndexRange(startIndex, i - 1)); + startIndex = i; + accumulatedSize = currentGroupSize; + bucketsByTypeNumber.clear(); + } + for (Map.Entry> entry : + subpartitionSlices.entrySet()) { + Integer typeNumber = entry.getKey(); + SubpartitionSlice subpartitionSlice = entry.getValue().get(i); + bucketsByTypeNumber + .computeIfAbsent(typeNumber, ignored -> new HashSet<>()) + .add(subpartitionSlice); + } + } + subpartitionSliceRanges.add(new IndexRange(startIndex, subpartitionGroupSize - 1)); + return subpartitionSliceRanges; + } + + /** + * The difference from {@link #computeSubpartitionSliceRanges} is that the calculation here only + * returns the parallelism after dividing base on the given limits. 
+ * + * @param limit the limit of data size + * @param subpartitionSlicesSize the number of SubpartitionSlices + * @param subpartitionSlices the subpartition slices to be processed + * @return the parallelism after dividing + */ + private static int computeParallelism( + long limit, + int subpartitionSlicesSize, + Map> subpartitionSlices) { + int count = 1; + long accumulatedSize = 0; + int startIndex = 0; + Map> bucketsByTypeNumber = new HashMap<>(); + for (int i = 0; i < subpartitionSlicesSize; ++i) { + long currentGroupSize = 0L; + long currentGroupSizeDeduplicated = 0L; + for (Map.Entry> entry : + subpartitionSlices.entrySet()) { + Integer typeNumber = entry.getKey(); + SubpartitionSlice subpartitionSlice = entry.getValue().get(i); + Set bucket = + bucketsByTypeNumber.computeIfAbsent(typeNumber, ignored -> new HashSet<>()); + if (!bucket.contains(subpartitionSlice)) { + currentGroupSizeDeduplicated += subpartitionSlice.getDataBytes(); + } + currentGroupSize += subpartitionSlice.getDataBytes(); + } + if (i == startIndex || accumulatedSize + currentGroupSizeDeduplicated <= limit) { + accumulatedSize += currentGroupSizeDeduplicated; + } else { + ++count; + startIndex = i; + accumulatedSize = currentGroupSize; + bucketsByTypeNumber.clear(); + } + for (Map.Entry> entry : + subpartitionSlices.entrySet()) { + Integer typeNumber = entry.getKey(); + SubpartitionSlice subpartitionSlice = entry.getValue().get(i); + bucketsByTypeNumber + .computeIfAbsent(typeNumber, ignored -> new HashSet<>()) + .add(subpartitionSlice); + } + } + return count; + } + + private static int checkAndGetSubpartitionSlicesSize( + Map> subpartitionSlices) { + Set subpartitionSliceSizes = + subpartitionSlices.values().stream().map(List::size).collect(Collectors.toSet()); + checkArgument(subpartitionSliceSizes.size() == 1); + return subpartitionSliceSizes.iterator().next(); + } + + private static Map reverseIndexRangeMap( + Map indexRangeMap) { + Map reversedRangeMap = new HashMap<>(); + for (Map.Entry entry : indexRangeMap.entrySet()) { + checkState(!reversedRangeMap.containsKey(entry.getValue())); + reversedRangeMap.put(entry.getValue(), entry.getKey()); + } + return reversedRangeMap; + } + + public static long calculateDataVolumePerTaskForInputsGroup( + long globalDataVolumePerTask, + List inputsGroup, + List allInputs) { + return calculateDataVolumePerTaskForInput( + globalDataVolumePerTask, + inputsGroup.stream().mapToLong(BlockingInputInfo::getNumBytesProduced).sum(), + allInputs.stream().mapToLong(BlockingInputInfo::getNumBytesProduced).sum()); + } + + public static long calculateDataVolumePerTaskForInput( + long globalDataVolumePerTask, long inputsGroupBytes, long totalDataBytes) { + return (long) ((double) inputsGroupBytes / totalDataBytes * globalDataVolumePerTask); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java index ccebd02383407..4e529faba17ff 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java @@ -105,18 +105,6 @@ void testNonBroadcastBytesCanNotDividedEvenly() { assertThat(parallelism).isEqualTo(9); } - @Test - void 
testDecideParallelismWithMaxSubpartitionLimitation() { - BlockingResultInfo resultInfo1 = - new TestingBlockingResultInfo(false, false, 1L, 1024, 1024); - BlockingResultInfo resultInfo2 = new TestingBlockingResultInfo(false, false, 1L, 512, 512); - - int parallelism = - createDeciderAndDecideParallelism( - 1, 100, BYTE_256_MB, Arrays.asList(resultInfo1, resultInfo2)); - assertThat(parallelism).isEqualTo(32); - } - @Test void testAllEdgesAllToAll() { AllToAllBlockingResultInfo resultInfo1 = @@ -295,21 +283,20 @@ void testHavePointwiseEdges() { parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()), Arrays.asList( new IndexRange(0, 1), - new IndexRange(2, 4), - new IndexRange(5, 6), - new IndexRange(7, 9))); - checkPointwiseJobVertexInputInfo( + new IndexRange(2, 5), + new IndexRange(6, 7), + new IndexRange(8, 9))); + checkJobVertexInputInfo( parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()), Arrays.asList( - new IndexRange(0, 0), - new IndexRange(0, 0), - new IndexRange(1, 1), - new IndexRange(1, 1)), - Arrays.asList( - new IndexRange(0, 1), - new IndexRange(2, 4), - new IndexRange(0, 1), - new IndexRange(2, 4))); + Map.of(new IndexRange(0, 0), new IndexRange(0, 1)), + Map.of(new IndexRange(0, 0), new IndexRange(2, 3)), + Map.of( + new IndexRange(0, 0), + new IndexRange(4, 4), + new IndexRange(1, 1), + new IndexRange(0, 1)), + Map.of(new IndexRange(1, 1), new IndexRange(2, 4)))); } @Test @@ -335,7 +322,7 @@ void testParallelismAlreadyDecided() { checkAllToAllJobVertexInputInfo( Iterables.getOnlyElement( parallelismAndInputInfos.getJobVertexInputInfos().values()), - Arrays.asList(new IndexRange(0, 2), new IndexRange(3, 5), new IndexRange(6, 9))); + Arrays.asList(new IndexRange(0, 2), new IndexRange(3, 6), new IndexRange(7, 9))); } @Test @@ -382,33 +369,6 @@ void testDynamicSourceParallelismWithUpstreamInputs() { new IndexRange(8, 9))); } - @Test - void testEvenlyDistributeDataWithMaxSubpartitionLimitation() { - long[] subpartitionBytes = new long[1024]; - Arrays.fill(subpartitionBytes, 1L); - AllToAllBlockingResultInfo resultInfo = - new AllToAllBlockingResultInfo( - new IntermediateDataSetID(), 1024, 1024, false, false); - for (int i = 0; i < 1024; ++i) { - resultInfo.recordPartitionInfo(i, new ResultPartitionBytes(subpartitionBytes)); - } - - ParallelismAndInputInfos parallelismAndInputInfos = - createDeciderAndDecideParallelismAndInputInfos( - 1, 100, BYTE_256_MB, Collections.singletonList(resultInfo)); - - assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(32); - List subpartitionRanges = new ArrayList<>(); - for (int i = 0; i < 32; ++i) { - subpartitionRanges.add(new IndexRange(i * 32, (i + 1) * 32 - 1)); - } - checkAllToAllJobVertexInputInfo( - Iterables.getOnlyElement( - parallelismAndInputInfos.getJobVertexInputInfos().values()), - new IndexRange(0, 1023), - subpartitionRanges); - } - @Test void testComputeSourceParallelismUpperBound() { Configuration configuration = new Configuration(); @@ -416,7 +376,7 @@ void testComputeSourceParallelismUpperBound() { BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_DEFAULT_SOURCE_PARALLELISM, DEFAULT_SOURCE_PARALLELISM); VertexParallelismAndInputInfosDecider vertexParallelismAndInputInfosDecider = - DefaultVertexParallelismAndInputInfosDecider.from(MAX_PARALLELISM, configuration); + createDefaultVertexParallelismAndInputInfosDecider(MAX_PARALLELISM, configuration); assertThat( vertexParallelismAndInputInfosDecider.computeSourceParallelismUpperBound( new JobVertexID(), 
VERTEX_MAX_PARALLELISM)) @@ -427,7 +387,7 @@ void testComputeSourceParallelismUpperBound() { void testComputeSourceParallelismUpperBoundFallback() { Configuration configuration = new Configuration(); VertexParallelismAndInputInfosDecider vertexParallelismAndInputInfosDecider = - DefaultVertexParallelismAndInputInfosDecider.from(MAX_PARALLELISM, configuration); + createDefaultVertexParallelismAndInputInfosDecider(MAX_PARALLELISM, configuration); assertThat( vertexParallelismAndInputInfosDecider.computeSourceParallelismUpperBound( new JobVertexID(), VERTEX_MAX_PARALLELISM)) @@ -441,7 +401,7 @@ void testComputeSourceParallelismUpperBoundNotExceedMaxParallelism() { BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_DEFAULT_SOURCE_PARALLELISM, VERTEX_MAX_PARALLELISM * 2); VertexParallelismAndInputInfosDecider vertexParallelismAndInputInfosDecider = - DefaultVertexParallelismAndInputInfosDecider.from(MAX_PARALLELISM, configuration); + createDefaultVertexParallelismAndInputInfosDecider(MAX_PARALLELISM, configuration); assertThat( vertexParallelismAndInputInfosDecider.computeSourceParallelismUpperBound( new JobVertexID(), VERTEX_MAX_PARALLELISM)) @@ -467,16 +427,13 @@ private static void checkAllToAllJobVertexInputInfo( .containsExactlyInAnyOrderElementsOf(executionVertexInputInfos); } - private static void checkPointwiseJobVertexInputInfo( + private static void checkJobVertexInputInfo( JobVertexInputInfo jobVertexInputInfo, - List partitionRanges, - List subpartitionRanges) { - assertThat(partitionRanges).hasSameSizeAs(subpartitionRanges); + List> consumedSubpartitionGroups) { List executionVertexInputInfos = new ArrayList<>(); - for (int i = 0; i < subpartitionRanges.size(); ++i) { + for (int i = 0; i < consumedSubpartitionGroups.size(); ++i) { executionVertexInputInfos.add( - new ExecutionVertexInputInfo( - i, partitionRanges.get(i), subpartitionRanges.get(i))); + new ExecutionVertexInputInfo(i, consumedSubpartitionGroups.get(i))); } assertThat(jobVertexInputInfo.getExecutionVertexInputInfos()) .containsExactlyInAnyOrderElementsOf(executionVertexInputInfos); @@ -504,7 +461,20 @@ static DefaultVertexParallelismAndInputInfosDecider createDecider( BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_DEFAULT_SOURCE_PARALLELISM, defaultSourceParallelism); - return DefaultVertexParallelismAndInputInfosDecider.from(maxParallelism, configuration); + return createDefaultVertexParallelismAndInputInfosDecider(maxParallelism, configuration); + } + + static DefaultVertexParallelismAndInputInfosDecider + createDefaultVertexParallelismAndInputInfosDecider( + int maxParallelism, Configuration configuration) { + return DefaultVertexParallelismAndInputInfosDecider.from( + maxParallelism, + BatchExecutionOptionsInternal.ADAPTIVE_SKEWED_OPTIMIZATION_SKEWED_FACTOR + .defaultValue(), + BatchExecutionOptionsInternal.ADAPTIVE_SKEWED_OPTIMIZATION_SKEWED_THRESHOLD + .defaultValue() + .getBytes(), + configuration); } private static int createDeciderAndDecideParallelism(List consumedResults) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/VertexInputInfoComputerTestUtil.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/VertexInputInfoComputerTestUtil.java new file mode 100644 index 0000000000000..ed06685189643 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/VertexInputInfoComputerTestUtil.java @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch; + +import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo; +import org.apache.flink.runtime.executiongraph.IndexRange; +import org.apache.flink.runtime.executiongraph.JobVertexInputInfo; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.flink.runtime.executiongraph.IndexRangeUtil.mergeIndexRanges; +import static org.assertj.core.api.Assertions.assertThat; + +/** Utils for vertex input info computer test. */ +public class VertexInputInfoComputerTestUtil { + public static List createBlockingInputInfos( + int typeNumber, + int numInputInfos, + int numPartitions, + int numSubpartitions, + boolean existIntraInputKeyCorrelation, + boolean existInterInputsKeyCorrelation, + int defaultSize, + double skewedFactor, + List skewedPartitionIndex, + List skewedSubpartitionIndex, + boolean isPointwise) { + List blockingInputInfos = new ArrayList<>(); + for (int i = 0; i < numInputInfos; i++) { + Map subpartitionBytesByPartitionIndex = new HashMap<>(); + for (int j = 0; j < numPartitions; j++) { + long[] subpartitionBytes = new long[numSubpartitions]; + for (int k = 0; k < numSubpartitions; k++) { + if (skewedSubpartitionIndex.contains(k) || skewedPartitionIndex.contains(j)) { + subpartitionBytes[k] = (long) (defaultSize * skewedFactor); + } else { + subpartitionBytes[k] = defaultSize; + } + } + subpartitionBytesByPartitionIndex.put(j, subpartitionBytes); + } + BlockingResultInfo resultInfo; + if (isPointwise) { + resultInfo = + new PointwiseBlockingResultInfo( + new IntermediateDataSetID(), + numPartitions, + numSubpartitions, + subpartitionBytesByPartitionIndex); + } else { + resultInfo = + new AllToAllBlockingResultInfo( + new IntermediateDataSetID(), + numPartitions, + numSubpartitions, + false, + subpartitionBytesByPartitionIndex); + } + blockingInputInfos.add( + new BlockingInputInfo( + resultInfo, + typeNumber, + existInterInputsKeyCorrelation, + existIntraInputKeyCorrelation)); + } + return blockingInputInfos; + } + + private static void checkParallelism( + int targetParallelism, + Map vertexInputInfoMap) { + vertexInputInfoMap + .values() + .forEach( + info -> + assertThat(info.getExecutionVertexInputInfos().size()) + .isEqualTo(targetParallelism)); + } + + public static void checkConsumedSubpartitionGroups( + List> targetConsumedSubpartitionGroups, + List inputInfos, + Map vertexInputInfoMap) { + JobVertexInputInfo vertexInputInfo = + checkAndGetJobVertexInputInfo(inputInfos, vertexInputInfoMap); + List executionVertexInputInfos = + 
vertexInputInfo.getExecutionVertexInputInfos(); + for (int i = 0; i < executionVertexInputInfos.size(); i++) { + assertThat(executionVertexInputInfos.get(i).getConsumedSubpartitionGroups()) + .isEqualTo(targetConsumedSubpartitionGroups.get(i)); + } + } + + public static void checkConsumedDataVolumePerSubtask( + long[] targetConsumedDataVolume, + List inputInfos, + Map vertexInputs) { + long[] consumedDataVolume = new long[targetConsumedDataVolume.length]; + for (BlockingInputInfo inputInfo : inputInfos) { + JobVertexInputInfo vertexInputInfo = vertexInputs.get(inputInfo.getResultId()); + List executionVertexInputInfos = + vertexInputInfo.getExecutionVertexInputInfos(); + for (int i = 0; i < executionVertexInputInfos.size(); ++i) { + ExecutionVertexInputInfo executionVertexInputInfo = + executionVertexInputInfos.get(i); + consumedDataVolume[i] += + executionVertexInputInfo.getConsumedSubpartitionGroups().entrySet().stream() + .mapToLong( + entry -> + inputInfo.getNumBytesProduced( + entry.getKey(), entry.getValue())) + .sum(); + } + } + assertThat(consumedDataVolume).isEqualTo(targetConsumedDataVolume); + } + + private static JobVertexInputInfo checkAndGetJobVertexInputInfo( + List inputInfos, + Map vertexInputInfoMap) { + List vertexInputInfos = + inputInfos.stream() + .map(inputInfo -> vertexInputInfoMap.get(inputInfo.getResultId())) + .collect(Collectors.toList()); + assertThat(vertexInputInfos.size()).isEqualTo(inputInfos.size()); + JobVertexInputInfo baseVertexInputInfo = vertexInputInfos.get(0); + for (int i = 1; i < vertexInputInfos.size(); i++) { + assertThat(vertexInputInfos.get(i)).isEqualTo(baseVertexInputInfo); + } + return baseVertexInputInfo; + } + + public static void checkCorrectnessForNonCorrelatedInput( + Map vertexInputInfoMap, + BlockingInputInfo inputInfo, + int targetParallelism) { + checkParallelism(targetParallelism, vertexInputInfoMap); + Map> consumedPartitionToSubpartitionRanges = new HashMap<>(); + vertexInputInfoMap + .get(inputInfo.getResultId()) + .getExecutionVertexInputInfos() + .forEach( + info -> + info.getConsumedSubpartitionGroups() + .forEach( + (partitionRange, subpartitionRange) -> { + for (int i = partitionRange.getStartIndex(); + i <= partitionRange.getEndIndex(); + ++i) { + consumedPartitionToSubpartitionRanges + .computeIfAbsent( + i, k -> new ArrayList<>()) + .add(subpartitionRange); + } + })); + Set partitionIndex = + IntStream.rangeClosed(0, inputInfo.getNumPartitions() - 1) + .boxed() + .collect(Collectors.toSet()); + IndexRange subpartitionRange = new IndexRange(0, inputInfo.getNumSubpartitions(0) - 1); + assertThat(consumedPartitionToSubpartitionRanges.keySet()).isEqualTo(partitionIndex); + consumedPartitionToSubpartitionRanges + .values() + .forEach( + subpartitionRanges -> { + List mergedRange = mergeIndexRanges(subpartitionRanges); + assertThat(mergedRange.size()).isEqualTo(1); + assertThat(mergedRange.get(0)).isEqualTo(subpartitionRange); + }); + } + + public static void checkCorrectnessForCorrelatedInputs( + Map vertexInputInfoMap, + List inputInfos, + int targetParallelism, + int numSubpartitions) { + checkParallelism(targetParallelism, vertexInputInfoMap); + Map> inputInfosGroupByTypeNumber = + inputInfos.stream() + .collect(Collectors.groupingBy(BlockingInputInfo::getInputTypeNumber)); + + Map> vertexInputInfosGroupByTypeNumber = + inputInfosGroupByTypeNumber.entrySet().stream() + .collect( + Collectors.toMap( + Map.Entry::getKey, + e -> + e.getValue().stream() + .map( + v -> + vertexInputInfoMap.get( + v.getResultId())) + 
.collect(Collectors.toList()))); + + Map vertexInputInfoToNumPartitionsMap = + inputInfosGroupByTypeNumber.values().stream() + .flatMap(List::stream) + .collect( + Collectors.toMap( + v -> vertexInputInfoMap.get(v.getResultId()), + BlockingInputInfo::getNumPartitions)); + assertThat(vertexInputInfosGroupByTypeNumber.size()).isEqualTo(2); + checkCorrectnessForCorrelatedInputs( + vertexInputInfosGroupByTypeNumber.get(1), + vertexInputInfosGroupByTypeNumber.get(2), + vertexInputInfoToNumPartitionsMap, + numSubpartitions); + } + + private static void checkCorrectnessForCorrelatedInputs( + List infosWithTypeNumber1, + List infosWithTypeNumber2, + Map vertexInputInfoToNumPartitionsMap, + int numSubpartitions) { + for (JobVertexInputInfo vertexInputInfo : infosWithTypeNumber1) { + for (JobVertexInputInfo jobVertexInputInfo : infosWithTypeNumber2) { + checkCorrectnessForConsumedSubpartitionRanges( + vertexInputInfo, + jobVertexInputInfo, + vertexInputInfoToNumPartitionsMap.get(vertexInputInfo), + vertexInputInfoToNumPartitionsMap.get(jobVertexInputInfo), + numSubpartitions); + } + } + } + + /** + * This method performs the following checks on inputInfo1 and inputInfo2: 1. Whether they + * subscribe to all subpartitions. 2. Whether the data in subpartitions with the same index + * across both inputInfo1 and inputInfo2 is traversed in a Cartesian product manner. + * + * @param inputInfo1 the inputInfo1 + * @param inputInfo2 the inputInfo2 + * @param numPartitions1 the number of partitions for inputInfo1 + * @param numPartitions2 the number of partitions for inputInfo2 + * @param numSubpartitions the number of subpartitions for both inputInfo1 and inputInfo2 + */ + private static void checkCorrectnessForConsumedSubpartitionRanges( + JobVertexInputInfo inputInfo1, + JobVertexInputInfo inputInfo2, + int numPartitions1, + int numPartitions2, + int numSubpartitions) { + assertThat(inputInfo1.getExecutionVertexInputInfos().size()) + .isEqualTo(inputInfo2.getExecutionVertexInputInfos().size()); + // subpartition index of input1 -> partition index of input1 -> partition index ranges of + // input2 + Map>> input1ToInput2 = new HashMap<>(); + for (int i = 0; i < inputInfo1.getExecutionVertexInputInfos().size(); i++) { + Map> subpartitionIndexToPartition1 = + getConsumedSubpartitionIndexToPartitionRanges( + inputInfo1 + .getExecutionVertexInputInfos() + .get(i) + .getConsumedSubpartitionGroups()); + Map> subpartitionIndexToPartition2 = + getConsumedSubpartitionIndexToPartitionRanges( + inputInfo2 + .getExecutionVertexInputInfos() + .get(i) + .getConsumedSubpartitionGroups()); + subpartitionIndexToPartition1.forEach( + (subpartitionIndex, partitionRanges) -> { + assertThat(subpartitionIndexToPartition2.containsKey(subpartitionIndex)) + .isTrue(); + partitionRanges.forEach( + partitionRange -> { + for (int j = partitionRange.getStartIndex(); + j <= partitionRange.getEndIndex(); + ++j) { + input1ToInput2 + .computeIfAbsent( + subpartitionIndex, k -> new HashMap<>()) + .computeIfAbsent(j, k -> new HashSet<>()) + .addAll( + subpartitionIndexToPartition2.get( + subpartitionIndex)); + } + }); + }); + } + Set partitionIndex = + IntStream.rangeClosed(0, numPartitions1 - 1).boxed().collect(Collectors.toSet()); + Set subpartitionIndexSet = + IntStream.rangeClosed(0, numSubpartitions - 1).boxed().collect(Collectors.toSet()); + IndexRange partitionRange2 = new IndexRange(0, numPartitions2 - 1); + assertThat(input1ToInput2.keySet()).isEqualTo(subpartitionIndexSet); + input1ToInput2.forEach( + (subpartitionIndex, 
input1ToInput2PartitionRanges) -> { + assertThat(input1ToInput2PartitionRanges.keySet()).isEqualTo(partitionIndex); + input1ToInput2PartitionRanges + .values() + .forEach( + partitionRanges -> { + List mergedRange = + mergeIndexRanges(partitionRanges); + assertThat(mergedRange.size()).isEqualTo(1); + assertThat(mergedRange.get(0)).isEqualTo(partitionRange2); + }); + }); + } + + private static Map> getConsumedSubpartitionIndexToPartitionRanges( + Map consumedSubpartitionGroups) { + Map> subpartitionIndexToPartition = new HashMap<>(); + consumedSubpartitionGroups.forEach( + (partitionRange, subpartitionRange) -> { + for (int j = subpartitionRange.getStartIndex(); + j <= subpartitionRange.getEndIndex(); + ++j) { + subpartitionIndexToPartition + .computeIfAbsent(j, key -> new HashSet<>()) + .add(partitionRange); + } + }); + return subpartitionIndexToPartition; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AggregatedBlockingInputInfoTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AggregatedBlockingInputInfoTest.java new file mode 100644 index 0000000000000..3dfca1ed388be --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AggregatedBlockingInputInfoTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch.util; + +import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.AllToAllVertexInputInfoComputerTest.createBlockingInputInfos; +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link AggregatedBlockingInputInfo}. 
*/ +public class AggregatedBlockingInputInfoTest { + @Test + void testAggregatedInputWithSameNumPartitions() { + List inputInfos = + createBlockingInputInfos(1, 10, 4, true, true, List.of(1)); + AggregatedBlockingInputInfo aggregatedBlockingInputInfo = + AggregatedBlockingInputInfo.createAggregatedBlockingInputInfo( + 40, 4, 20, inputInfos); + assertThat(aggregatedBlockingInputInfo.getMaxPartitionNum()).isEqualTo(4); + assertThat(aggregatedBlockingInputInfo.getNumSubpartitions()).isEqualTo(3); + assertThat(aggregatedBlockingInputInfo.getTargetSize()).isEqualTo(40); + assertThat(aggregatedBlockingInputInfo.isSplittable()).isFalse(); + assertThat(aggregatedBlockingInputInfo.isSkewedSubpartition(0)).isEqualTo(false); + assertThat(aggregatedBlockingInputInfo.isSkewedSubpartition(1)).isEqualTo(true); + assertThat(aggregatedBlockingInputInfo.isSkewedSubpartition(2)).isEqualTo(false); + assertThat(aggregatedBlockingInputInfo.getAggregatedSubpartitionBytes(0)).isEqualTo(40); + assertThat(aggregatedBlockingInputInfo.getAggregatedSubpartitionBytes(1)).isEqualTo(400); + assertThat(aggregatedBlockingInputInfo.getAggregatedSubpartitionBytes(2)).isEqualTo(40); + assertThat(aggregatedBlockingInputInfo.getSubpartitionBytesByPartition()).isEmpty(); + + // all subpartition bytes larger than skewed threshold + AggregatedBlockingInputInfo aggregatedBlockingInputInfo2 = + AggregatedBlockingInputInfo.createAggregatedBlockingInputInfo( + 30, 4, 20, inputInfos); + assertThat(aggregatedBlockingInputInfo2.getTargetSize()).isEqualTo(40); + + // larger than skewed factor but less than skewed threshold + AggregatedBlockingInputInfo aggregatedBlockingInputInfo3 = + AggregatedBlockingInputInfo.createAggregatedBlockingInputInfo( + 500, 4, 20, inputInfos); + assertThat(aggregatedBlockingInputInfo3.getTargetSize()).isEqualTo(160); + + // larger than skewed threshold but less than skewed factor + AggregatedBlockingInputInfo aggregatedBlockingInputInfo4 = + AggregatedBlockingInputInfo.createAggregatedBlockingInputInfo( + 100, 20, 20, inputInfos); + assertThat(aggregatedBlockingInputInfo4.getTargetSize()).isEqualTo(160); + + List inputInfosWithoutIntraCorrelation = + createBlockingInputInfos(2, 10, 4, false, true, List.of(1)); + AggregatedBlockingInputInfo aggregatedBlockingInputInfo5 = + AggregatedBlockingInputInfo.createAggregatedBlockingInputInfo( + 40, 4, 20, inputInfosWithoutIntraCorrelation); + assertThat(aggregatedBlockingInputInfo5.getSubpartitionBytesByPartition()) + .containsExactlyInAnyOrderEntriesOf( + Map.of( + 0, + new long[] {10, 100, 10}, + 1, + new long[] {10, 100, 10}, + 2, + new long[] {10, 100, 10}, + 3, + new long[] {10, 100, 10})); + } + + @Test + void testAggregatedInputWithDifferentNumPartitions() { + List inputInfos = new ArrayList<>(); + inputInfos.addAll(createBlockingInputInfos(1, 10, 4, false, true, List.of())); + inputInfos.addAll(createBlockingInputInfos(1, 1, 5, false, true, List.of())); + AggregatedBlockingInputInfo aggregatedBlockingInputInfo = + AggregatedBlockingInputInfo.createAggregatedBlockingInputInfo( + 40, 4, 20, inputInfos); + assertThat(aggregatedBlockingInputInfo.getMaxPartitionNum()).isEqualTo(5); + assertThat(aggregatedBlockingInputInfo.getNumSubpartitions()).isEqualTo(3); + assertThat(aggregatedBlockingInputInfo.getSubpartitionBytesByPartition()).isEmpty(); + assertThat(aggregatedBlockingInputInfo.isSplittable()).isFalse(); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AllToAllVertexInputInfoComputerTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AllToAllVertexInputInfoComputerTest.java new file mode 100644 index 0000000000000..314faa76566ae --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/AllToAllVertexInputInfoComputerTest.java @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch.util; + +import org.apache.flink.runtime.executiongraph.IndexRange; +import org.apache.flink.runtime.executiongraph.JobVertexInputInfo; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo; +import org.apache.flink.runtime.scheduler.adaptivebatch.VertexInputInfoComputerTestUtil; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.runtime.scheduler.adaptivebatch.VertexInputInfoComputerTestUtil.checkConsumedDataVolumePerSubtask; +import static org.apache.flink.runtime.scheduler.adaptivebatch.VertexInputInfoComputerTestUtil.checkConsumedSubpartitionGroups; +import static org.apache.flink.runtime.scheduler.adaptivebatch.VertexInputInfoComputerTestUtil.checkCorrectnessForCorrelatedInputs; +import static org.apache.flink.runtime.scheduler.adaptivebatch.VertexInputInfoComputerTestUtil.checkCorrectnessForNonCorrelatedInput; + +/** Tests for {@link AllToAllVertexInputInfoComputer}. 
*/ +class AllToAllVertexInputInfoComputerTest { + @Test + void testComputeInputsWithIntraInputKeyCorrelation() { + testComputeInputsWithIntraInputKeyCorrelation(1); + testComputeInputsWithIntraInputKeyCorrelation(10); + } + + void testComputeInputsWithIntraInputKeyCorrelation(int numInputInfos) { + AllToAllVertexInputInfoComputer computer = createAllToAllVertexInputInfoComputer(); + List inputInfos = new ArrayList<>(); + List leftInputInfos = + createBlockingInputInfos(1, numInputInfos, 10, true, true, List.of()); + List rightInputInfos = + createBlockingInputInfos(2, numInputInfos, 10, true, true, List.of()); + inputInfos.addAll(leftInputInfos); + inputInfos.addAll(rightInputInfos); + Map vertexInputs = + computer.compute(new JobVertexID(), inputInfos, 10, 1, 10, 10); + + checkCorrectnessForCorrelatedInputs(vertexInputs, inputInfos, 3, 3); + + List> targetConsumedSubpartitionGroups = + List.of( + Map.of(new IndexRange(0, 9), new IndexRange(0, 0)), + Map.of(new IndexRange(0, 9), new IndexRange(1, 1)), + Map.of(new IndexRange(0, 9), new IndexRange(2, 2))); + checkConsumedSubpartitionGroups(targetConsumedSubpartitionGroups, inputInfos, vertexInputs); + + checkConsumedDataVolumePerSubtask( + new long[] {10L * numInputInfos, 10L * numInputInfos, 10L * numInputInfos}, + leftInputInfos, + vertexInputs); + + checkConsumedDataVolumePerSubtask( + new long[] {10L * numInputInfos, 10L * numInputInfos, 10L * numInputInfos}, + rightInputInfos, + vertexInputs); + } + + @Test + void testInputsUnionWithDifferentNumPartitions() { + AllToAllVertexInputInfoComputer computer = createAllToAllVertexInputInfoComputer(); + List inputInfos = new ArrayList<>(); + List leftInputInfos1 = + createBlockingInputInfos(1, 1, 2, true, true, List.of()); + List leftInputInfos2 = + createBlockingInputInfos(1, 1, 3, true, true, List.of()); + List rightInputInfos = + createBlockingInputInfos(2, 1, 2, true, true, List.of()); + inputInfos.addAll(leftInputInfos1); + inputInfos.addAll(leftInputInfos2); + inputInfos.addAll(rightInputInfos); + Map vertexInputs = + computer.compute(new JobVertexID(), inputInfos, 2, 1, 2, 10); + + checkCorrectnessForCorrelatedInputs(vertexInputs, inputInfos, 2, 3); + + List> left1TargetConsumedSubpartitionGroups = + List.of( + Map.of(new IndexRange(0, 1), new IndexRange(0, 1)), + Map.of(new IndexRange(0, 1), new IndexRange(2, 2))); + + List> left2TargetConsumedSubpartitionGroups = + List.of( + Map.of(new IndexRange(0, 2), new IndexRange(0, 1)), + Map.of(new IndexRange(0, 2), new IndexRange(2, 2))); + + List> rightTargetConsumedSubpartitionGroups = + List.of( + Map.of(new IndexRange(0, 1), new IndexRange(0, 1)), + Map.of(new IndexRange(0, 1), new IndexRange(2, 2))); + + checkConsumedSubpartitionGroups( + left1TargetConsumedSubpartitionGroups, leftInputInfos1, vertexInputs); + checkConsumedSubpartitionGroups( + left2TargetConsumedSubpartitionGroups, leftInputInfos2, vertexInputs); + checkConsumedSubpartitionGroups( + rightTargetConsumedSubpartitionGroups, rightInputInfos, vertexInputs); + } + + @Test + void testComputeSkewedInputWithIntraInputKeyCorrelation() { + testComputeSkewedInputWithIntraInputKeyCorrelation(1); + testComputeSkewedInputWithIntraInputKeyCorrelation(10); + } + + void testComputeSkewedInputWithIntraInputKeyCorrelation(int numInputInfos) { + AllToAllVertexInputInfoComputer computer = createAllToAllVertexInputInfoComputer(); + List inputInfos = new ArrayList<>(); + List leftInputInfos = + createBlockingInputInfos(1, numInputInfos, 10, true, true, List.of(0)); + List rightInputInfos = + 
createBlockingInputInfos(2, numInputInfos, 10, true, true, List.of()); + inputInfos.addAll(leftInputInfos); + inputInfos.addAll(rightInputInfos); + Map vertexInputs = + computer.compute(new JobVertexID(), inputInfos, 10, 1, 10, 10); + + checkCorrectnessForCorrelatedInputs(vertexInputs, inputInfos, 3, 3); + + List> targetConsumedSubpartitionGroups = + List.of( + Map.of(new IndexRange(0, 9), new IndexRange(0, 0)), + Map.of(new IndexRange(0, 9), new IndexRange(1, 1)), + Map.of(new IndexRange(0, 9), new IndexRange(2, 2))); + checkConsumedSubpartitionGroups(targetConsumedSubpartitionGroups, inputInfos, vertexInputs); + + checkConsumedDataVolumePerSubtask( + new long[] {100L * numInputInfos, 10L * numInputInfos, 10L * numInputInfos}, + leftInputInfos, + vertexInputs); + + checkConsumedDataVolumePerSubtask( + new long[] {10L * numInputInfos, 10L * numInputInfos, 10L * numInputInfos}, + rightInputInfos, + vertexInputs); + } + + @Test + void testComputeSkewedInputWithoutIntraInputKeyCorrelation() { + testComputeSkewedInputWithoutIntraInputKeyCorrelation(1); + testComputeSkewedInputWithoutIntraInputKeyCorrelation(10); + } + + void testComputeSkewedInputWithoutIntraInputKeyCorrelation(int numInputInfos) { + AllToAllVertexInputInfoComputer computer = createAllToAllVertexInputInfoComputer(); + List inputInfos = new ArrayList<>(); + List leftInputInfos = + createBlockingInputInfos(1, numInputInfos, 10, false, true, List.of(0)); + List rightInputInfos = + createBlockingInputInfos(2, numInputInfos, 10, true, true, List.of()); + inputInfos.addAll(leftInputInfos); + inputInfos.addAll(rightInputInfos); + Map vertexInputs = + computer.compute(new JobVertexID(), inputInfos, 10, 1, 10, 10); + + checkCorrectnessForCorrelatedInputs(vertexInputs, inputInfos, 7, 3); + checkConsumedDataVolumePerSubtask( + new long[] { + 20L * numInputInfos, + 20L * numInputInfos, + 20L * numInputInfos, + 20L * numInputInfos, + 20L * numInputInfos, + 10L * numInputInfos, + 10L * numInputInfos + }, + leftInputInfos, + vertexInputs); + checkConsumedDataVolumePerSubtask( + new long[] { + 10L * numInputInfos, + 10L * numInputInfos, + 10L * numInputInfos, + 10L * numInputInfos, + 10L * numInputInfos, + 10L * numInputInfos, + 10L * numInputInfos + }, + rightInputInfos, + vertexInputs); + } + + @Test + void testComputeMultipleSkewedInputsWithoutIntraInputKeyCorrelation() { + testComputeMultipleSkewedInputsWithoutIntraInputKeyCorrelation(1); + testComputeMultipleSkewedInputsWithoutIntraInputKeyCorrelation(10); + } + + void testComputeMultipleSkewedInputsWithoutIntraInputKeyCorrelation(int numInputInfos) { + AllToAllVertexInputInfoComputer computer = createAllToAllVertexInputInfoComputer(); + List inputInfos = new ArrayList<>(); + List leftInputInfos = + createBlockingInputInfos(1, numInputInfos, 2, false, true, List.of(1)); + List rightInputInfos = + createBlockingInputInfos(2, numInputInfos, 2, false, true, List.of(1)); + inputInfos.addAll(leftInputInfos); + inputInfos.addAll(rightInputInfos); + + Map vertexInputs = + computer.compute(new JobVertexID(), inputInfos, 1, 1, 2, 10); + checkCorrectnessForCorrelatedInputs(vertexInputs, inputInfos, 2, 3); + checkConsumedDataVolumePerSubtask( + new long[] {12L * numInputInfos, 12L * numInputInfos}, + leftInputInfos, + vertexInputs); + checkConsumedDataVolumePerSubtask( + new long[] {22L * numInputInfos, 22L * numInputInfos}, + rightInputInfos, + vertexInputs); + + // with smaller max parallelism + Map vertexInputs2 = + computer.compute(new JobVertexID(), inputInfos, 1, 1, 1, 10); + 
checkCorrectnessForCorrelatedInputs(vertexInputs2, inputInfos, 1, 3); + checkConsumedDataVolumePerSubtask( + new long[] {24L * numInputInfos}, leftInputInfos, vertexInputs2); + checkConsumedDataVolumePerSubtask( + new long[] {24L * numInputInfos}, rightInputInfos, vertexInputs2); + + // with bigger max parallelism + Map vertexInputs4 = + computer.compute(new JobVertexID(), inputInfos, 1, 1, 5, 10); + checkCorrectnessForCorrelatedInputs(vertexInputs4, inputInfos, 4, 3); + + checkConsumedDataVolumePerSubtask( + new long[] { + 12L * numInputInfos, + 10L * numInputInfos, + 10L * numInputInfos, + 12L * numInputInfos + }, + leftInputInfos, + vertexInputs4); + checkConsumedDataVolumePerSubtask( + new long[] { + 12L * numInputInfos, + 10L * numInputInfos, + 10L * numInputInfos, + 12L * numInputInfos + }, + rightInputInfos, + vertexInputs4); + + // with bigger min parallelism + Map vertexInputs5 = + computer.compute(new JobVertexID(), inputInfos, 5, 5, 5, 10); + checkCorrectnessForCorrelatedInputs(vertexInputs5, inputInfos, 5, 3); + + checkConsumedDataVolumePerSubtask( + new long[] { + 2L * numInputInfos, + 10L * numInputInfos, + 10L * numInputInfos, + 10L * numInputInfos, + 12L * numInputInfos + }, + leftInputInfos, + vertexInputs5); + checkConsumedDataVolumePerSubtask( + new long[] { + 2L * numInputInfos, + 10L * numInputInfos, + 10L * numInputInfos, + 10L * numInputInfos, + 12L * numInputInfos + }, + rightInputInfos, + vertexInputs5); + } + + @Test + void testComputeSkewedInputsWithDifferentNumPartitions() { + AllToAllVertexInputInfoComputer computer = createAllToAllVertexInputInfoComputer(); + List inputInfos = new ArrayList<>(); + List leftInputInfos = new ArrayList<>(); + leftInputInfos.addAll(createBlockingInputInfos(1, 1, 2, false, true, List.of(1))); + leftInputInfos.addAll(createBlockingInputInfos(1, 1, 3, false, true, List.of(1))); + List rightInputInfos = + createBlockingInputInfos(2, 1, 2, false, true, List.of(1)); + inputInfos.addAll(leftInputInfos); + inputInfos.addAll(rightInputInfos); + Map vertexInputs = + computer.compute(new JobVertexID(), inputInfos, 2, 1, 2, 10); + + checkCorrectnessForCorrelatedInputs(vertexInputs, inputInfos, 2, 3); + checkConsumedDataVolumePerSubtask(new long[] {55L, 55L}, leftInputInfos, vertexInputs); + checkConsumedDataVolumePerSubtask(new long[] {12L, 12L}, rightInputInfos, vertexInputs); + } + + @Test + void testComputeRebalancedWithAndWithoutCorrelations() { + AllToAllVertexInputInfoComputer computer = createAllToAllVertexInputInfoComputer(); + + List inputInfoWithCorrelations = + createBlockingInputInfos(1, 1, 10, true, false, List.of()); + Map vertexInputs1 = + computer.compute(new JobVertexID(), inputInfoWithCorrelations, 2, 1, 5, 10); + checkCorrectnessForNonCorrelatedInput(vertexInputs1, inputInfoWithCorrelations.get(0), 2); + checkConsumedDataVolumePerSubtask( + new long[] {20L, 10L}, inputInfoWithCorrelations, vertexInputs1); + + List inputInfoWithoutCorrelations = + createBlockingInputInfos(1, 1, 10, false, false, List.of()); + Map vertexInputs2 = + computer.compute(new JobVertexID(), inputInfoWithoutCorrelations, 2, 1, 5, 1); + checkCorrectnessForNonCorrelatedInput( + vertexInputs2, inputInfoWithoutCorrelations.get(0), 2); + checkConsumedDataVolumePerSubtask( + new long[] {15L, 15L}, inputInfoWithoutCorrelations, vertexInputs2); + + // with different parallelism + Map vertexInputs3 = + computer.compute(new JobVertexID(), inputInfoWithoutCorrelations, 3, 1, 5, 10); + checkCorrectnessForNonCorrelatedInput( + vertexInputs3, 
inputInfoWithoutCorrelations.get(0), 3);
+ checkConsumedDataVolumePerSubtask(
+ new long[] {10L, 10L, 10L}, inputInfoWithoutCorrelations, vertexInputs3);
+ }
+
+ @Test
+ void testComputeInputsWithDifferentCorrelations() {
+ AllToAllVertexInputInfoComputer computer = createAllToAllVertexInputInfoComputer();
+ List<BlockingInputInfo> inputInfos = new ArrayList<>();
+ List<BlockingInputInfo> rebalancedInputInfos =
+ createBlockingInputInfos(1, 1, 10, false, false, List.of());
+ List<BlockingInputInfo> normalInputInfos =
+ createBlockingInputInfos(1, 1, 10, true, true, List.of());
+ inputInfos.addAll(rebalancedInputInfos);
+ inputInfos.addAll(normalInputInfos);
+
+ Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputs =
+ computer.compute(new JobVertexID(), inputInfos, 10, 1, 10, 10);
+ // although normalInputInfos declare inter-inputs key correlation, they are not actually
+ // correlated with any other input because they are the only correlated input, so the
+ // check method for non-correlated inputs can be used here
+ checkCorrectnessForNonCorrelatedInput(vertexInputs, normalInputInfos.get(0), 3);
+ checkConsumedDataVolumePerSubtask(
+ new long[] {10L, 10L, 10L}, normalInputInfos, vertexInputs);
+
+ checkCorrectnessForNonCorrelatedInput(vertexInputs, rebalancedInputInfos.get(0), 3);
+ checkConsumedDataVolumePerSubtask(
+ new long[] {10L, 10L, 10L}, rebalancedInputInfos, vertexInputs);
+ }
+
+ @Test
+ void testComputeWithLargeDataVolumePerTask() {
+ AllToAllVertexInputInfoComputer computer = createAllToAllVertexInputInfoComputer();
+ List<BlockingInputInfo> inputInfos =
+ createBlockingInputInfos(1, 1, 10, true, true, List.of());
+
+ Map<IntermediateDataSetID, JobVertexInputInfo> vertexInputs =
+ computer.compute(new JobVertexID(), inputInfos, 10, 1, 10, 100);
+ checkCorrectnessForNonCorrelatedInput(vertexInputs, inputInfos.get(0), 1);
+ checkConsumedDataVolumePerSubtask(new long[] {30L}, inputInfos, vertexInputs);
+ }
+
+ public static List<BlockingInputInfo> createBlockingInputInfos(
+ int typeNumber,
+ int numInputInfos,
+ int numPartitions,
+ boolean isIntraInputKeyCorrelated,
+ boolean areInterInputsKeysCorrelated,
+ List<Integer> skewedSubpartitionIndex) {
+ return VertexInputInfoComputerTestUtil.createBlockingInputInfos(
+ typeNumber,
+ numInputInfos,
+ numPartitions,
+ 3,
+ isIntraInputKeyCorrelated,
+ areInterInputsKeysCorrelated,
+ 1,
+ 10,
+ List.of(),
+ skewedSubpartitionIndex,
+ false);
+ }
+
+ private static AllToAllVertexInputInfoComputer createAllToAllVertexInputInfoComputer() {
+ return new AllToAllVertexInputInfoComputer(4, 10);
+ }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/PointwiseVertexInputInfoComputerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/PointwiseVertexInputInfoComputerTest.java
new file mode 100644
index 0000000000000..c70225b31ed26
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/PointwiseVertexInputInfoComputerTest.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch.util; + +import org.apache.flink.runtime.executiongraph.IndexRange; +import org.apache.flink.runtime.executiongraph.JobVertexInputInfo; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo; +import org.apache.flink.runtime.scheduler.adaptivebatch.VertexInputInfoComputerTestUtil; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.runtime.scheduler.adaptivebatch.VertexInputInfoComputerTestUtil.checkConsumedDataVolumePerSubtask; +import static org.apache.flink.runtime.scheduler.adaptivebatch.VertexInputInfoComputerTestUtil.checkConsumedSubpartitionGroups; +import static org.apache.flink.runtime.scheduler.adaptivebatch.VertexInputInfoComputerTestUtil.checkCorrectnessForNonCorrelatedInput; + +/** Tests for {@link PointwiseVertexInputInfoComputer}. */ +class PointwiseVertexInputInfoComputerTest { + + @Test + void testComputeNormalInput() { + PointwiseVertexInputInfoComputer computer = createPointwiseVertexInputInfoComputer(); + List inputInfos = createBlockingInputInfos(2, List.of(), false); + Map vertexInputs = + computer.compute(inputInfos, 2, 10); + checkCorrectnessForNonCorrelatedInput(vertexInputs, inputInfos.get(0), 2); + checkConsumedDataVolumePerSubtask(new long[] {3L, 3L}, inputInfos, vertexInputs); + + // with different parallelism + Map vertexInputs2 = + computer.compute(inputInfos, 3, 10); + checkCorrectnessForNonCorrelatedInput(vertexInputs2, inputInfos.get(0), 3); + checkConsumedDataVolumePerSubtask(new long[] {2L, 2L, 2L}, inputInfos, vertexInputs2); + } + + @Test + void testComputeSkewedInputsWithDifferentSkewedPartitions() { + PointwiseVertexInputInfoComputer computer = createPointwiseVertexInputInfoComputer(); + List inputInfosWithDifferentSkewedPartitions = new ArrayList<>(); + BlockingInputInfo inputInfo1 = createBlockingInputInfo(3, 3, List.of(0), false); + BlockingInputInfo inputInfo2 = createBlockingInputInfo(3, 3, List.of(1), false); + inputInfosWithDifferentSkewedPartitions.add(inputInfo1); + inputInfosWithDifferentSkewedPartitions.add(inputInfo2); + Map vertexInputs = + computer.compute(inputInfosWithDifferentSkewedPartitions, 3, 10); + checkCorrectnessForNonCorrelatedInput(vertexInputs, inputInfo1, 3); + checkConsumedDataVolumePerSubtask( + new long[] {10L, 10L, 16L}, List.of(inputInfo1), vertexInputs); + + checkCorrectnessForNonCorrelatedInput(vertexInputs, inputInfo2, 3); + checkConsumedDataVolumePerSubtask( + new long[] {13L, 10L, 13L}, List.of(inputInfo2), vertexInputs); + } + + @Test + void testComputeSkewedInputsWithDifferentNumPartitions() { + PointwiseVertexInputInfoComputer computer = createPointwiseVertexInputInfoComputer(); + List inputInfosWithDifferentNumPartitions = new ArrayList<>(); + BlockingInputInfo inputInfo1 = createBlockingInputInfo(3, 3, List.of(1), false); + BlockingInputInfo inputInfo2 = createBlockingInputInfo(2, 3, List.of(1), false); + inputInfosWithDifferentNumPartitions.add(inputInfo1); + inputInfosWithDifferentNumPartitions.add(inputInfo2); + Map vertexInputs = + computer.compute(inputInfosWithDifferentNumPartitions, 3, 10); + checkCorrectnessForNonCorrelatedInput(vertexInputs, inputInfo1, 3); + checkConsumedDataVolumePerSubtask( + new long[] {13L, 10L, 13L}, List.of(inputInfo1), 
vertexInputs); + + checkCorrectnessForNonCorrelatedInput(vertexInputs, inputInfo2, 3); + checkConsumedDataVolumePerSubtask( + new long[] {13L, 10L, 10L}, List.of(inputInfo2), vertexInputs); + } + + @Test + void testComputeSkewedInputsWithDifferentNumSubpartitions() { + PointwiseVertexInputInfoComputer computer = createPointwiseVertexInputInfoComputer(); + List inputInfosWithDifferentNumSubpartitions = new ArrayList<>(); + BlockingInputInfo inputInfo1 = createBlockingInputInfo(3, 3, List.of(1), false); + BlockingInputInfo inputInfo2 = createBlockingInputInfo(3, 5, List.of(1), false); + inputInfosWithDifferentNumSubpartitions.add(inputInfo1); + inputInfosWithDifferentNumSubpartitions.add(inputInfo2); + Map vertexInputs = + computer.compute(inputInfosWithDifferentNumSubpartitions, 3, 10); + checkCorrectnessForNonCorrelatedInput(vertexInputs, inputInfo1, 3); + checkConsumedDataVolumePerSubtask( + new long[] {13L, 10L, 13L}, List.of(inputInfo1), vertexInputs); + + checkCorrectnessForNonCorrelatedInput(vertexInputs, inputInfo2, 3); + checkConsumedDataVolumePerSubtask( + new long[] {25L, 20L, 15L}, List.of(inputInfo2), vertexInputs); + } + + @Test + void testComputeInputWithIntraCorrelation() { + PointwiseVertexInputInfoComputer computer = createPointwiseVertexInputInfoComputer(); + List inputInfos = createBlockingInputInfos(3, List.of(), true); + Map vertexInputs = + computer.compute(inputInfos, 3, 10); + checkCorrectnessForNonCorrelatedInput(vertexInputs, inputInfos.get(0), 3); + checkConsumedSubpartitionGroups( + List.of( + Map.of(new IndexRange(0, 0), new IndexRange(0, 2)), + Map.of(new IndexRange(1, 1), new IndexRange(0, 2)), + Map.of(new IndexRange(2, 2), new IndexRange(0, 2))), + inputInfos, + vertexInputs); + + // with different parallelism + Map vertexInputs2 = + computer.compute(inputInfos, 2, 10); + checkCorrectnessForNonCorrelatedInput(vertexInputs2, inputInfos.get(0), 2); + checkConsumedSubpartitionGroups( + List.of( + Map.of(new IndexRange(0, 1), new IndexRange(0, 2)), + Map.of(new IndexRange(2, 2), new IndexRange(0, 2))), + inputInfos, + vertexInputs2); + } + + private static List createBlockingInputInfos( + int numPartitions, + List skewedPartitionIndex, + boolean existIntraInputKeyCorrelation) { + return List.of( + createBlockingInputInfo( + numPartitions, 3, skewedPartitionIndex, existIntraInputKeyCorrelation)); + } + + private static BlockingInputInfo createBlockingInputInfo( + int numPartitions, + int numSubpartitions, + List skewedPartitionIndex, + boolean existIntraInputKeyCorrelation) { + return VertexInputInfoComputerTestUtil.createBlockingInputInfos( + 1, + 1, + numPartitions, + numSubpartitions, + existIntraInputKeyCorrelation, + false, + 1, + 10, + skewedPartitionIndex, + List.of(), + true) + .get(0); + } + + private static PointwiseVertexInputInfoComputer createPointwiseVertexInputInfoComputer() { + return new PointwiseVertexInputInfoComputer(); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/SubpartitionSliceTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/SubpartitionSliceTest.java new file mode 100644 index 0000000000000..bb7731c371602 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/SubpartitionSliceTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch.util; + +import org.apache.flink.runtime.executiongraph.IndexRange; +import org.apache.flink.runtime.scheduler.adaptivebatch.BlockingInputInfo; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.AllToAllVertexInputInfoComputerTest.createBlockingInputInfos; +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link SubpartitionSlice}. */ +public class SubpartitionSliceTest { + @Test + void testCreateSubpartitionSlice() { + SubpartitionSlice subpartitionSlice = + SubpartitionSlice.createSubpartitionSlice( + new IndexRange(0, 2), new IndexRange(0, 3), 10); + assertThat(subpartitionSlice.getSubpartitionRange()).isEqualTo(new IndexRange(0, 3)); + assertThat(subpartitionSlice.getPartitionRange(2)).isEqualTo(new IndexRange(0, 1)); + assertThat(subpartitionSlice.getPartitionRange(3)).isEqualTo(new IndexRange(0, 2)); + assertThat(subpartitionSlice.getPartitionRange(4)).isEqualTo(new IndexRange(0, 2)); + assertThat(subpartitionSlice.getDataBytes()).isEqualTo(10); + } + + @Test + void testCreateSubpartitionSlices() { + // 1,10,1 1,10,1 1,10,1 + BlockingInputInfo inputInfos = + createBlockingInputInfos(1, 1, 3, true, true, List.of(1)).get(0); + + List subpartitionSlices = + SubpartitionSlice.createSubpartitionSlicesByMultiPartitionRanges( + List.of(new IndexRange(0, 0), new IndexRange(1, 1), new IndexRange(2, 2)), + new IndexRange(0, 0), + inputInfos.getSubpartitionBytesByPartitionIndex()); + checkSubpartitionSlices( + subpartitionSlices, + List.of(new IndexRange(0, 0), new IndexRange(1, 1), new IndexRange(2, 2)), + new IndexRange(0, 0), + new long[] {1L, 1L, 1L}, + 3); + + List subpartitionSlices2 = + SubpartitionSlice.createSubpartitionSlicesByMultiPartitionRanges( + List.of(new IndexRange(0, 1), new IndexRange(2, 2)), + new IndexRange(0, 0), + inputInfos.getSubpartitionBytesByPartitionIndex()); + checkSubpartitionSlices( + subpartitionSlices2, + List.of(new IndexRange(0, 1), new IndexRange(2, 2)), + new IndexRange(0, 0), + new long[] {2L, 1L}, + 3); + + List subpartitionSlices3 = + SubpartitionSlice.createSubpartitionSlicesByMultiPartitionRanges( + List.of(new IndexRange(0, 0), new IndexRange(1, 1), new IndexRange(2, 2)), + new IndexRange(1, 1), + inputInfos.getSubpartitionBytesByPartitionIndex()); + checkSubpartitionSlices( + subpartitionSlices3, + List.of(new IndexRange(0, 0), new IndexRange(1, 1), new IndexRange(2, 2)), + new IndexRange(1, 1), + new long[] {10L, 10L, 10L}, + 3); + } + + private void checkSubpartitionSlices( + List subpartitionSlices, + List partitionRanges, + IndexRange subpartitionRange, + long[] dataBytes, + int numPartitions) { + for (int i = 0; i < subpartitionSlices.size(); ++i) { + SubpartitionSlice subpartitionSlice 
= subpartitionSlices.get(i); + assertThat(subpartitionSlice.getPartitionRange(numPartitions)) + .isEqualTo(partitionRanges.get(i)); + assertThat(subpartitionSlice.getSubpartitionRange()).isEqualTo(subpartitionRange); + assertThat(subpartitionSlice.getDataBytes()).isEqualTo(dataBytes[i]); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtilsTest.java new file mode 100644 index 0000000000000..5fbbc3eace8a4 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/util/VertexParallelismAndInputInfosDeciderUtilsTest.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch.util; + +import org.apache.flink.runtime.executiongraph.IndexRange; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.cartesianProduct; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.computeSkewThreshold; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.computeTargetSize; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.median; +import static org.apache.flink.runtime.scheduler.adaptivebatch.util.VertexParallelismAndInputInfosDeciderUtils.tryComputeSubpartitionSliceRange; +import static org.assertj.core.api.Assertions.assertThat; + +/** Test for {@link VertexParallelismAndInputInfosDeciderUtils}. 
*/ +class VertexParallelismAndInputInfosDeciderUtilsTest { + @Test + void testCartesianProduct() { + // empty input + List> inputEmpty = List.of(); + List> expectedEmpty = List.of(List.of()); + List> resultEmpty = cartesianProduct(inputEmpty); + assertThat(resultEmpty).isEqualTo(expectedEmpty); + + // two lists + List> inputTwo = Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3, 4)); + List> expectedTwo = + Arrays.asList( + Arrays.asList(1, 3), + Arrays.asList(1, 4), + Arrays.asList(2, 3), + Arrays.asList(2, 4)); + List> resultTwo = cartesianProduct(inputTwo); + assertThat(resultTwo).isEqualTo(expectedTwo); + + // three lists + List> inputThree = + Arrays.asList( + Arrays.asList("A", "B"), Arrays.asList("1", "2"), Arrays.asList("X", "Y")); + List> expectedThree = + Arrays.asList( + Arrays.asList("A", "1", "X"), + Arrays.asList("A", "1", "Y"), + Arrays.asList("A", "2", "X"), + Arrays.asList("A", "2", "Y"), + Arrays.asList("B", "1", "X"), + Arrays.asList("B", "1", "Y"), + Arrays.asList("B", "2", "X"), + Arrays.asList("B", "2", "Y")); + List> resultThree = cartesianProduct(inputThree); + assertThat(resultThree).isEqualTo(expectedThree); + } + + @Test + void testMedian() { + long[] numsOdd = {5, 1, 3}; + long resultOdd = median(numsOdd); + assertThat(resultOdd).isEqualTo(3L); + + long[] numsEven = {7, 3, 9, 1}; + long resultEven = median(numsEven); + assertThat(resultEven).isEqualTo(5L); + + long[] numsSame = {2, 2, 2, 2, 2}; + long resultSame = median(numsSame); + assertThat(resultSame).isEqualTo(2L); + + long[] numsSingle = {8}; + long resultSingle = median(numsSingle); + assertThat(resultSingle).isEqualTo(8L); + + long[] numsEdges = {2, 4}; + long resultEdges = median(numsEdges); + assertThat(resultEdges).isEqualTo(3L); + + long[] numsLessThanOne = {1, 2, 3, 0, 0}; + long resultLessThanOne = median(numsLessThanOne); + assertThat(resultLessThanOne).isEqualTo(1L); + } + + @Test + void computeSkewThresholdTest() { + long mediaSize1 = 100; + double skewedFactor1 = 1.5; + long defaultSkewedThreshold1 = 50; + long result1 = computeSkewThreshold(mediaSize1, skewedFactor1, defaultSkewedThreshold1); + assertThat(result1).isEqualTo(150L); + + // threshold less than default + long mediaSize2 = 40; + double skewedFactor2 = 1.0; + long defaultSkewedThreshold2 = 50; + long result2 = computeSkewThreshold(mediaSize2, skewedFactor2, defaultSkewedThreshold2); + assertThat(result2).isEqualTo(50L); + } + + @Test + void testComputeTargetSize() { + long[] subpartitionBytes1 = {100, 200, 150, 50}; + long skewedThreshold1 = 150; + long dataVolumePerTask1 = 75; + long result1 = computeTargetSize(subpartitionBytes1, skewedThreshold1, dataVolumePerTask1); + assertThat(result1).isEqualTo(100L); + + // with a larger data volume per task + long[] subpartitionBytes2 = {200, 180, 70, 30}; + long skewedThreshold2 = 100; + long dataVolumePerTask2 = 80; + long result2 = computeTargetSize(subpartitionBytes2, skewedThreshold2, dataVolumePerTask2); + assertThat(result2).isEqualTo(80L); + + // No skewed partitions + long[] subpartitionBytes3 = {100, 50, 75}; + long skewedThreshold3 = 200; + long dataVolumePerTask3 = 60; + long result3 = computeTargetSize(subpartitionBytes3, skewedThreshold3, dataVolumePerTask3); + assertThat(result3).isEqualTo(75L); + } + + @Test + void testComputeSubpartitionSliceRange() { + Map> subpartitionSlicesByTypeNumber = + Map.of(1, createSubpartitionSlices(5, new long[] {100, 200, 300, 200, 100})); + + Optional> subpartitionSliceRanges = + tryComputeSubpartitionSliceRange(1, 5, 300, 
subpartitionSlicesByTypeNumber); + assertThat(subpartitionSliceRanges).isNotEmpty(); + assertThat(subpartitionSliceRanges.get()) + .isEqualTo( + List.of(new IndexRange(0, 1), new IndexRange(2, 2), new IndexRange(3, 4))); + + // test with a big max data volume per task + subpartitionSliceRanges = + tryComputeSubpartitionSliceRange(1, 5, 10000, subpartitionSlicesByTypeNumber); + assertThat(subpartitionSliceRanges).isNotEmpty(); + assertThat(subpartitionSliceRanges.get()).isEqualTo(List.of(new IndexRange(0, 4))); + + // test with a small max data volume per task + subpartitionSliceRanges = + tryComputeSubpartitionSliceRange(1, 5, 100, subpartitionSlicesByTypeNumber); + assertThat(subpartitionSliceRanges).isNotEmpty(); + assertThat(subpartitionSliceRanges.get()) + .isEqualTo( + List.of( + new IndexRange(0, 0), + new IndexRange(1, 1), + new IndexRange(2, 2), + new IndexRange(3, 3), + new IndexRange(4, 4))); + + // test fallback to adjust to the closest parallelism + subpartitionSliceRanges = + tryComputeSubpartitionSliceRange(5, 5, 200, subpartitionSlicesByTypeNumber); + assertThat(subpartitionSliceRanges).isNotEmpty(); + assertThat(subpartitionSliceRanges.get()) + .isEqualTo( + List.of( + new IndexRange(0, 0), + new IndexRange(1, 1), + new IndexRange(2, 2), + new IndexRange(3, 3), + new IndexRange(4, 4))); + + // test fallback to evenly distributed subpartition slices + subpartitionSliceRanges = + tryComputeSubpartitionSliceRange(4, 4, 200, subpartitionSlicesByTypeNumber); + assertThat(subpartitionSliceRanges).isNotEmpty(); + assertThat(subpartitionSliceRanges.get()) + .isEqualTo( + List.of( + new IndexRange(0, 0), + new IndexRange(1, 1), + new IndexRange(2, 2), + new IndexRange(3, 4))); + + // test failed to compute slice range + subpartitionSliceRanges = + tryComputeSubpartitionSliceRange(6, 6, 200, subpartitionSlicesByTypeNumber); + assertThat(subpartitionSliceRanges).isEmpty(); + } + + List createSubpartitionSlices(int numSlices, long[] dataBytesPerSlice) { + List subpartitionSlices = new ArrayList<>(); + for (int i = 0; i < numSlices; ++i) { + subpartitionSlices.add( + SubpartitionSlice.createSubpartitionSlice( + new IndexRange(0, 0), new IndexRange(i, i), dataBytesPerSlice[i])); + } + return subpartitionSlices; + } +}
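
Note for reviewers: the assertions in VertexParallelismAndInputInfosDeciderUtilsTest above pin down the arithmetic of the skew-detection helpers. The following standalone sketch shows one arithmetic that satisfies every asserted value for median, computeSkewThreshold, computeTargetSize and cartesianProduct. It is inferred from the test data only and is not the implementation in VertexParallelismAndInputInfosDeciderUtils; the class name DeciderUtilsSketch is invented for illustration, while the method names simply mirror the test's static imports.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class DeciderUtilsSketch {

    // {5,1,3} -> 3, {7,3,9,1} -> 5, {2,4} -> 3, {8} -> 8, matching the expectations in testMedian.
    static long median(long[] nums) {
        long[] sorted = nums.clone();
        Arrays.sort(sorted);
        int n = sorted.length;
        return n % 2 == 0 ? (sorted[n / 2 - 1] + sorted[n / 2]) / 2 : sorted[n / 2];
    }

    // (100, 1.5, 50) -> 150 and (40, 1.0, 50) -> 50, i.e. the larger of
    // medianSize * skewedFactor and the configured default threshold.
    static long computeSkewThreshold(long medianSize, double skewedFactor, long defaultSkewedThreshold) {
        return Math.max((long) (medianSize * skewedFactor), defaultSkewedThreshold);
    }

    // ({100,200,150,50}, 150, 75) -> 100, ({200,180,70,30}, 100, 80) -> 80 and
    // ({100,50,75}, 200, 60) -> 75: average the subpartitions that do not exceed the skew
    // threshold, but never go below the configured data volume per task.
    static long computeTargetSize(long[] subpartitionBytes, long skewedThreshold, long dataVolumePerTask) {
        long sum = 0L;
        int count = 0;
        for (long bytes : subpartitionBytes) {
            if (bytes <= skewedThreshold) {
                sum += bytes;
                count++;
            }
        }
        long averageOfNonSkewed = count == 0 ? 0L : sum / count;
        return Math.max(averageOfNonSkewed, dataVolumePerTask);
    }

    // Matches testCartesianProduct: an empty input yields a single empty combination, and the
    // first list varies slowest, e.g. [[1,2],[3,4]] -> [[1,3],[1,4],[2,3],[2,4]].
    static <T> List<List<T>> cartesianProduct(List<List<T>> lists) {
        List<List<T>> result = new ArrayList<>();
        result.add(new ArrayList<>());
        for (List<T> list : lists) {
            List<List<T>> next = new ArrayList<>();
            for (List<T> partial : result) {
                for (T element : list) {
                    List<T> combination = new ArrayList<>(partial);
                    combination.add(element);
                    next.add(combination);
                }
            }
            result = next;
        }
        return result;
    }
}

Each commented value can be cross-checked directly against the corresponding assertion in the test class above.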
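Similarly, SubpartitionSliceTest fixes the clamping behaviour of SubpartitionSlice.getPartitionRange: a slice built over partition range [0, 2] reports [0, 1] when only 2 partitions are known and [0, 2] once 3 or more are known. The minimal stand-in below illustrates that behaviour under those assumptions; Range and SubpartitionSliceSketch are invented names, not the types in flink-runtime.

final class SubpartitionSliceSketch {

    // Simplified stand-in for org.apache.flink.runtime.executiongraph.IndexRange.
    static final class Range {
        final int startIndex;
        final int endIndex;

        Range(int startIndex, int endIndex) {
            this.startIndex = startIndex;
            this.endIndex = endIndex;
        }

        @Override
        public String toString() {
            return "[" + startIndex + ", " + endIndex + "]";
        }
    }

    private final Range partitionRange;
    private final Range subpartitionRange;
    private final long dataBytes;

    SubpartitionSliceSketch(Range partitionRange, Range subpartitionRange, long dataBytes) {
        this.partitionRange = partitionRange;
        this.subpartitionRange = subpartitionRange;
        this.dataBytes = dataBytes;
    }

    // getPartitionRange(2) -> [0, 1], getPartitionRange(3) -> [0, 2], getPartitionRange(4) -> [0, 2]:
    // the stored partition range is returned unchanged when it fits into the consumer's known
    // partition count, otherwise its end index is clamped to numPartitions - 1.
    Range getPartitionRange(int numPartitions) {
        if (partitionRange.endIndex < numPartitions) {
            return partitionRange;
        }
        return new Range(partitionRange.startIndex, numPartitions - 1);
    }

    Range getSubpartitionRange() {
        return subpartitionRange;
    }

    long getDataBytes() {
        return dataBytes;
    }
}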