[Spark] Restore memory sensitive GBK translation (#33520) #33521

Merged
@@ -49,6 +49,7 @@
import org.apache.beam.runners.core.ReduceFnRunner;
import org.apache.beam.runners.core.StateInternalsFactory;
import org.apache.beam.runners.core.SystemReduceFn;
import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn;
import org.apache.beam.sdk.coders.KvCoder;
@@ -84,8 +85,10 @@
*
* <p>Note: Using {@code collect_list} isn't any worse than using {@link ReduceFnRunner}. In the
* latter case the entire group (iterator) has to be loaded into memory as well. Either way there's
* a risk of OOM errors. When disabling {@link #useCollectList}, a more memory sensitive iterable is
* used that can be traversed just once. Attempting to traverse the iterable again will throw.
* a risk of OOM errors. When enabling {@link
* SparkCommonPipelineOptions#getPreferGroupByKeyToHandleHugeValues()}, a more memory sensitive
* iterable is used that can be traversed just once. Attempting to traverse the iterable again will
* throw.
*
* <ul>
* <li>When using the default global window, window information is dropped and restored after the
@@ -108,17 +111,10 @@ class GroupByKeyTranslatorBatch<K, V>
private static final List<Expression> GLOBAL_WINDOW_DETAILS =
windowDetails(lit(new byte[][] {EMPTY_BYTE_ARRAY}));

private boolean useCollectList = true;

GroupByKeyTranslatorBatch() {
super(0.2f);
}

GroupByKeyTranslatorBatch(boolean useCollectList) {
super(0.2f);
this.useCollectList = useCollectList;
}

@Override
public void translate(GroupByKey<K, V> transform, Context cxt) {
WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
@@ -135,6 +131,10 @@ public void translate(GroupByKey<K, V> transform, Context cxt) {
// In batch we can ignore triggering and allowed lateness parameters
final Dataset<WindowedValue<KV<K, Iterable<V>>>> result;

boolean useCollectList =
!cxt.getOptions()
.as(SparkCommonPipelineOptions.class)
.getPreferGroupByKeyToHandleHugeValues();
if (useCollectList && eligibleForGlobalGroupBy(windowing, false)) {
// Collects all values per key in memory. This might be problematic if there's few keys only
// or some highly skewed distribution.
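Editorial note (not part of the diff): with the change above, the structured streaming translator no longer takes a useCollectList constructor flag and instead reads the behavior from the pipeline options. A minimal sketch, assuming standard Beam option handling and a String[] args of command-line arguments, of how a pipeline author would opt in to the memory-sensitive GroupByKey translation:

import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;

// Sketch only: enable the memory-sensitive GroupByKey translation.
SparkCommonPipelineOptions options =
    PipelineOptionsFactory.fromArgs(args).withValidation().as(SparkCommonPipelineOptions.class);
options.setPreferGroupByKeyToHandleHugeValues(true); // defaults to false, see SparkCommonPipelineOptions below
Pipeline pipeline = Pipeline.create(options);
// Grouped values may now be backed by an iterable that can be traversed only once; a second traversal throws.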
@@ -26,9 +26,12 @@

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.SerializableMatcher;
@@ -48,21 +51,37 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.junit.runners.Parameterized;

/** Test class for Beam to Spark {@link GroupByKey} translation. */
@RunWith(JUnit4.class)
@RunWith(Parameterized.class)
public class GroupByKeyTest implements Serializable {
@ClassRule public static final SparkSessionRule SESSION = new SparkSessionRule();

@Parameterized.Parameter public boolean preferGroupByKeyToHandleHugeValues;

@Parameterized.Parameters(name = "Test with preferGroupByKeyToHandleHugeValues={0}")
public static Collection<Object[]> preferGroupByKeyToHandleHugeValues() {
return Arrays.asList(new Object[][] {{true}, {false}});
}

@Rule
public transient TestPipeline pipeline =
TestPipeline.fromOptions(SESSION.createPipelineOptions());

@Before
public void updatePipelineOptions() {
pipeline
.getOptions()
.as(SparkCommonPipelineOptions.class)
.setPreferGroupByKeyToHandleHugeValues(preferGroupByKeyToHandleHugeValues);
}

@Test
public void testGroupByKeyPreservesWindowing() {
pipeline
@@ -63,6 +63,15 @@ public interface SparkCommonPipelineOptions

void setEnableSparkMetricSinks(Boolean enableSparkMetricSinks);

@Description(
"When set to true, the runner will prefer a GroupByKey translation which can handle huge values and "
+ "does not require them to fit into memory. This will most likely have a performance impact "
+ "for pipelines that do not work with huge values, hence it is disabled by default.")
@Default.Boolean(false)
Boolean getPreferGroupByKeyToHandleHugeValues();

void setPreferGroupByKeyToHandleHugeValues(Boolean preferGroupByKeyToHandleHugeValues);

/**
* Returns the default checkpoint directory of /tmp/${job.name}. For testing purposes only.
* Production applications should use a reliable filesystem such as HDFS/S3/GS.
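A hedged usage aside (editorial, not part of the diff): because Beam derives pipeline argument names from an option's property name, the new switch should also be settable as a command-line argument, so users can toggle the memory-sensitive translation without code changes:

// Sketch only: supplying the option as a pipeline argument to the classic Spark runner.
String[] args = {"--runner=SparkRunner", "--preferGroupByKeyToHandleHugeValues=true"};
SparkCommonPipelineOptions options =
    PipelineOptionsFactory.fromArgs(args).as(SparkCommonPipelineOptions.class);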
@@ -30,6 +30,7 @@
import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.metrics.SparkBeamMetricSource;
import org.apache.beam.runners.spark.translation.EvaluationContext;
import org.apache.beam.runners.spark.translation.GroupByKeyVisitor;
import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
import org.apache.beam.runners.spark.translation.TransformEvaluator;
@@ -214,6 +215,9 @@ public SparkPipelineResult run(final Pipeline pipeline) {
// update the cache candidates
updateCacheCandidates(pipeline, translator, evaluationContext);

// update GBK candidates for memory optimized transform
pipeline.traverseTopologically(new GroupByKeyVisitor(translator, evaluationContext));

initAccumulators(pipelineOptions, jsc);
startPipeline =
executorService.submit(
@@ -65,6 +65,8 @@ public class EvaluationContext {
private AppliedPTransform<?, ?, ?> currentTransform;
private final SparkPCollectionView pviews = new SparkPCollectionView();
private final Map<PCollection, Long> cacheCandidates = new HashMap<>();
private final Map<GroupByKey<?, ?>, String> groupByKeyCandidatesForMemoryOptimizedTranslation =
new HashMap<>();
private final PipelineOptions options;
private final SerializablePipelineOptions serializableOptions;

@@ -282,6 +284,29 @@ public Map<PCollection, Long> getCacheCandidates() {
return this.cacheCandidates;
}

/**
* Returns the map of GBK transforms to their full names that are candidates for the group-by-key-and-window
* translation, which aims to reduce memory usage.
*
* @return The current {@link Map} of candidates
*/
public Map<GroupByKey<?, ?>, String> getCandidatesForGroupByKeyAndWindowTranslation() {
return this.groupByKeyCandidatesForMemoryOptimizedTranslation;
}

/**
* Returns whether the given GBK transform is a candidate for the group-by-key-and-window
* translation aiming to reduce memory usage.
*
* @param transform to evaluate
* @return true if the given transform is a candidate; false otherwise
* @param <K> type of GBK key
* @param <V> type of GBK value
*/
public <K, V> boolean isCandidateForGroupByKeyAndWindow(GroupByKey<K, V> transform) {
return groupByKeyCandidatesForMemoryOptimizedTranslation.containsKey(transform);
}

<T> Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) {
@SuppressWarnings("unchecked")
BoundedDataset<T> boundedDataset = (BoundedDataset<T>) datasets.get(pcollection);
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.spark.translation;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.join.CoGroupByKey;
import org.apache.beam.sdk.util.construction.PTransformTranslation;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;

/** Traverses the pipeline to populate the candidates for group by key. */
public class GroupByKeyVisitor extends Pipeline.PipelineVisitor.Defaults {

protected final EvaluationContext ctxt;
protected final SparkPipelineTranslator translator;
private boolean isInsideCoGBK = false;
private long visitedGroupByKeyTransformsCount = 0;

public GroupByKeyVisitor(
SparkPipelineTranslator translator, EvaluationContext evaluationContext) {
this.ctxt = evaluationContext;
this.translator = translator;
}

@Override
public Pipeline.PipelineVisitor.CompositeBehavior enterCompositeTransform(
TransformHierarchy.Node node) {
if (node.getTransform() != null && node.getTransform() instanceof CoGroupByKey<?>) {
isInsideCoGBK = true;
}
return CompositeBehavior.ENTER_TRANSFORM;
}

@Override
public void leaveCompositeTransform(TransformHierarchy.Node node) {
if (isInsideCoGBK && node.getTransform() instanceof CoGroupByKey<?>) {
isInsideCoGBK = false;
}
}

@Override
public void visitPrimitiveTransform(TransformHierarchy.Node node) {
PTransform<?, ?> transform = node.getTransform();
if (transform != null) {
String urn = PTransformTranslation.urnForTransformOrNull(transform);
if (PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN.equals(urn)) {
visitedGroupByKeyTransformsCount += 1;
if (!isInsideCoGBK) {
ctxt.getCandidatesForGroupByKeyAndWindowTranslation()
.put((GroupByKey<?, ?>) transform, node.getFullName());
}
}
}
}

@VisibleForTesting
long getVisitedGroupByKeyTransformsCount() {
return visitedGroupByKeyTransformsCount;
}
}
@@ -146,32 +146,44 @@ public void evaluate(GroupByKey<K, V> transform, EvaluationContext context) {

JavaRDD<WindowedValue<KV<K, Iterable<V>>>> groupedByKey;
Partitioner partitioner = getPartitioner(context);
// As this is batch, we can ignore triggering and allowed lateness parameters.
if (windowingStrategy.getWindowFn().equals(new GlobalWindows())
&& windowingStrategy.getTimestampCombiner().equals(TimestampCombiner.END_OF_WINDOW)) {
// we can drop the windows and recover them later
groupedByKey =
GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow(
inRDD, keyCoder, coder.getValueCoder(), partitioner);
} else if (GroupNonMergingWindowsFunctions.isEligibleForGroupByWindow(windowingStrategy)) {
// we can have a memory sensitive translation for non-merging windows
boolean enableHugeValuesTranslation =
context
.getOptions()
.as(SparkPipelineOptions.class)
.getPreferGroupByKeyToHandleHugeValues();
if (enableHugeValuesTranslation
&& context.isCandidateForGroupByKeyAndWindow(transform)
&& GroupNonMergingWindowsFunctions.isEligibleForGroupByWindow(windowingStrategy)) {
// we prefer memory sensitive translation of GBK which can support large values per
// key and does not require them to fit into memory
groupedByKey =
GroupNonMergingWindowsFunctions.groupByKeyAndWindow(
inRDD, keyCoder, coder.getValueCoder(), windowingStrategy, partitioner);
} else {
// --- group by key only.
JavaRDD<KV<K, Iterable<WindowedValue<V>>>> groupedByKeyOnly =
GroupCombineFunctions.groupByKeyOnly(inRDD, keyCoder, wvCoder, partitioner);

// --- now group also by window.
// for batch, GroupAlsoByWindow uses an in-memory StateInternals.
groupedByKey =
groupedByKeyOnly.flatMap(
new SparkGroupAlsoByWindowViaOutputBufferFn<>(
windowingStrategy,
new TranslationUtils.InMemoryStateInternalsFactory<>(),
SystemReduceFn.buffering(coder.getValueCoder()),
context.getSerializableOptions()));
// As this is batch, we can ignore triggering and allowed lateness parameters.
if (windowingStrategy.getWindowFn().equals(new GlobalWindows())
&& windowingStrategy.getTimestampCombiner().equals(TimestampCombiner.END_OF_WINDOW)) {

// we can drop the windows and recover them later
groupedByKey =
GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow(
inRDD, keyCoder, coder.getValueCoder(), partitioner);
} else {
// --- group by key only.
JavaRDD<KV<K, Iterable<WindowedValue<V>>>> groupedByKeyOnly =
GroupCombineFunctions.groupByKeyOnly(inRDD, keyCoder, wvCoder, partitioner);

// --- now group also by window.
// for batch, GroupAlsoByWindow uses an in-memory StateInternals.
groupedByKey =
groupedByKeyOnly.flatMap(
new SparkGroupAlsoByWindowViaOutputBufferFn<>(
windowingStrategy,
new TranslationUtils.InMemoryStateInternalsFactory<>(),
SystemReduceFn.buffering(coder.getValueCoder()),
context.getSerializableOptions()));
}
}
context.putDataset(transform, new BoundedDataset<>(groupedByKey));
}
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.spark.translation;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.Assert.assertEquals;

import org.apache.beam.runners.spark.SparkContextRule;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.join.CoGroupByKey;
import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TupleTag;
import org.junit.ClassRule;
import org.junit.Test;

/** Tests of {@link GroupByKeyVisitor}. */
public class GroupByKeyVisitorTest {

@ClassRule public static SparkContextRule contextRule = new SparkContextRule();

@Test
public void testTraverseShouldPopulateCandidatesIntoEvaluationContext() {
SparkPipelineOptions options = contextRule.createPipelineOptions();
Pipeline pipeline = Pipeline.create(options);
PCollection<KV<Integer, String>> pCollection =
pipeline.apply(Create.of(KV.of(3, "foo"), KV.of(3, "bar")));

pCollection.apply("CandidateGBK_1", Reshuffle.viaRandomKey());
pCollection.apply("CandidateGBK_2", GroupByKey.create());

final TupleTag<String> t1 = new TupleTag<>();
final TupleTag<String> t2 = new TupleTag<>();

KeyedPCollectionTuple.of(t1, pCollection)
.and(t2, pCollection)
.apply("GBK_inside_CoGBK_ignored", CoGroupByKey.create());

EvaluationContext ctxt =
new EvaluationContext(contextRule.getSparkContext(), pipeline, options);
GroupByKeyVisitor visitor = new GroupByKeyVisitor(new TransformTranslator.Translator(), ctxt);
pipeline.traverseTopologically(visitor);

assertEquals(3, visitor.getVisitedGroupByKeyTransformsCount());
assertEquals(2, ctxt.getCandidatesForGroupByKeyAndWindowTranslation().size());
assertThat(
ctxt.getCandidatesForGroupByKeyAndWindowTranslation().values(),
containsInAnyOrder("CandidateGBK_1/Reshuffle/GroupByKey", "CandidateGBK_2"));
}
}