Skip to content

Commit

Permalink
Improve memory efficiency in FnApiDoFnRunner (#33522)
Browse files Browse the repository at this point in the history
  • Loading branch information
stankiewicz authored Jan 14, 2025
1 parent 6d3c57f commit 93476eb
Show file tree
Hide file tree
Showing 14 changed files with 305 additions and 180 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.apache.beam.model.fnexecution.v1.BeamFnApi.RemoteGrpcPort;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
import org.apache.beam.runners.core.metrics.MonitoringInfoEncodings;
Expand Down Expand Up @@ -95,7 +94,7 @@ public BeamFnDataReadRunner<OutputT> createRunnerForPTransform(Context context)
context.getPTransformId(),
context.getPTransform(),
context.getProcessBundleInstructionIdSupplier(),
context.getCoders(),
context.getComponents(),
context.getBeamFnStateClient(),
context::addBundleProgressReporter,
consumer);
Expand Down Expand Up @@ -127,7 +126,7 @@ public BeamFnDataReadRunner<OutputT> createRunnerForPTransform(Context context)
String pTransformId,
RunnerApi.PTransform grpcReadNode,
Supplier<String> processBundleInstructionIdSupplier,
Map<String, RunnerApi.Coder> coders,
RunnerApi.Components components,
BeamFnStateClient beamFnStateClient,
Consumer<BundleProgressReporter> addBundleProgressReporter,
FnDataReceiver<WindowedValue<OutputT>> consumer)
Expand All @@ -138,13 +137,12 @@ public BeamFnDataReadRunner<OutputT> createRunnerForPTransform(Context context)
this.processBundleInstructionIdSupplier = processBundleInstructionIdSupplier;
this.consumer = consumer;

RehydratedComponents components =
RehydratedComponents.forComponents(Components.newBuilder().putAllCoders(coders).build());
RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(components);
this.coder =
(Coder<WindowedValue<OutputT>>)
CoderTranslation.fromProto(
coders.get(port.getCoderId()),
components,
components.getCodersMap().get(port.getCoderId()),
rehydratedComponents,
new StateBackedIterableTranslationContext() {
@Override
public Supplier<Cache<?, ?>> getCache() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.beam.fn.harness.state.StateBackedIterable.StateBackedIterableTranslationContext;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.RemoteGrpcPort;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.data.RemoteGrpcPortWrite;
import org.apache.beam.sdk.util.WindowedValue;
Expand Down Expand Up @@ -64,13 +63,11 @@ static class Factory<InputT> implements PTransformRunnerFactory<BeamFnDataWriteR
public BeamFnDataWriteRunner createRunnerForPTransform(Context context) throws IOException {

RemoteGrpcPort port = RemoteGrpcPortWrite.fromPTransform(context.getPTransform()).getPort();
RehydratedComponents components =
RehydratedComponents.forComponents(
Components.newBuilder().putAllCoders(context.getCoders()).build());
RehydratedComponents components = RehydratedComponents.forComponents(context.getComponents());
Coder<WindowedValue<InputT>> coder =
(Coder<WindowedValue<InputT>>)
CoderTranslation.fromProto(
context.getCoders().get(port.getCoderId()),
context.getComponents().getCodersMap().get(port.getCoderId()),
components,
new StateBackedIterableTranslationContext() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,14 @@ public PrecombineRunner<KeyT, InputT, AccumT> createRunnerForPTransform(Context
throws IOException {
// Get objects needed to create the runner.
RehydratedComponents rehydratedComponents =
RehydratedComponents.forComponents(
RunnerApi.Components.newBuilder()
.putAllCoders(context.getCoders())
.putAllWindowingStrategies(context.getWindowingStrategies())
.build());
RehydratedComponents.forComponents(context.getComponents());
String mainInputTag =
Iterables.getOnlyElement(context.getPTransform().getInputsMap().keySet());
RunnerApi.PCollection mainInput =
context.getPCollections().get(context.getPTransform().getInputsOrThrow(mainInputTag));
context
.getComponents()
.getPcollectionsMap()
.get(context.getPTransform().getInputsOrThrow(mainInputTag));

// Input coder may sometimes be WindowedValueCoder depending on runner, instead of the
// expected KvCoder.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,7 @@ static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT,
context.getCacheTokensSupplier(),
context.getBundleCacheSupplier(),
context.getProcessWideCache(),
context.getPCollections(),
context.getCoders(),
context.getWindowingStrategies(),
context.getComponents(),
context::addStartBundleFunction,
context::addFinishBundleFunction,
context::addResetFunction,
Expand Down Expand Up @@ -354,9 +352,7 @@ static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT,
Supplier<List<BeamFnApi.ProcessBundleRequest.CacheToken>> cacheTokens,
Supplier<Cache<?, ?>> bundleCache,
Cache<?, ?> processWideCache,
Map<String, PCollection> pCollections,
Map<String, RunnerApi.Coder> coders,
Map<String, RunnerApi.WindowingStrategy> windowingStrategies,
RunnerApi.Components components,
Consumer<ThrowingRunnable> addStartFunction,
Consumer<ThrowingRunnable> addFinishFunction,
Consumer<ThrowingRunnable> addResetFunction,
Expand All @@ -375,13 +371,7 @@ static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT,
ImmutableMap.builder();
try {
rehydratedComponents =
RehydratedComponents.forComponents(
RunnerApi.Components.newBuilder()
.putAllCoders(coders)
.putAllPcollections(pCollections)
.putAllWindowingStrategies(windowingStrategies)
.build())
.withPipeline(Pipeline.create());
RehydratedComponents.forComponents(components).withPipeline(Pipeline.create());
parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload);
doFnSignature = DoFnSignatures.signatureForDoFn(doFn);
Expand All @@ -404,7 +394,8 @@ static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT,
Iterables.getOnlyElement(
Sets.difference(
pTransform.getInputsMap().keySet(), parDoPayload.getSideInputsMap().keySet()));
PCollection mainInput = pCollections.get(pTransform.getInputsOrThrow(mainInputTag));
PCollection mainInput =
components.getPcollectionsMap().get(pTransform.getInputsOrThrow(mainInputTag));
Coder<?> maybeWindowedValueInputCoder = rehydratedComponents.getCoder(mainInput.getCoderId());
// TODO: Stop passing windowed value coders within PCollections.
if (maybeWindowedValueInputCoder instanceof WindowedValue.WindowedValueCoder) {
Expand All @@ -426,7 +417,8 @@ static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT,
outputCoders = Maps.newHashMap();
for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
TupleTag<?> outputTag = new TupleTag<>(entry.getKey());
RunnerApi.PCollection outputPCollection = pCollections.get(entry.getValue());
RunnerApi.PCollection outputPCollection =
components.getPcollectionsMap().get(entry.getValue());
Coder<?> outputCoder = rehydratedComponents.getCoder(outputPCollection.getCoderId());
if (outputCoder instanceof WindowedValueCoder) {
outputCoder = ((WindowedValueCoder) outputCoder).getValueCoder();
Expand All @@ -443,7 +435,7 @@ static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT,
String sideInputTag = entry.getKey();
RunnerApi.SideInput sideInput = entry.getValue();
PCollection sideInputPCollection =
pCollections.get(pTransform.getInputsOrThrow(sideInputTag));
components.getPcollectionsMap().get(pTransform.getInputsOrThrow(sideInputTag));
WindowingStrategy sideInputWindowingStrategy =
rehydratedComponents.getWindowingStrategy(
sideInputPCollection.getWindowingStrategyId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ interface Context {
/** The id of the PTransform. */
String getPTransformId();

/**
* The immutable pipeline components: a mapping from coder id to coder definition, a mapping
* from windowing strategy id to windowing strategy definition, and a mapping from PCollection
* id to PCollection definition.
*/
RunnerApi.Components getComponents();

/** The PTransform definition. */
RunnerApi.PTransform getPTransform();

Expand All @@ -77,15 +84,6 @@ interface Context {
/** A cache that is process wide and persists across bundle boundaries. */
Cache<?, ?> getProcessWideCache();

/** An immutable mapping from PCollection id to PCollection definition. */
Map<String, RunnerApi.PCollection> getPCollections();

/** An immutable mapping from coder id to coder definition. */
Map<String, RunnerApi.Coder> getCoders();

/** An immutable mapping from windowing strategy id to windowing strategy definition. */
Map<String, RunnerApi.WindowingStrategy> getWindowingStrategies();

/** An immutable set of runner capability urns. */
Set<String> getRunnerCapabilities();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,8 @@
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Coder;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
import org.apache.beam.model.pipeline.v1.RunnerApi.StandardRunnerProtocols;
import org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy;
import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
import org.apache.beam.runners.core.metrics.ShortIdMap;
import org.apache.beam.sdk.fn.data.BeamFnDataInboundObserver;
Expand Down Expand Up @@ -236,6 +233,7 @@ private void createRunnerAndConsumersForPTransformRecursively(
Supplier<List<CacheToken>> cacheTokens,
Supplier<Cache<?, ?>> bundleCache,
ProcessBundleDescriptor processBundleDescriptor,
RunnerApi.Components components,
SetMultimap<String, String> pCollectionIdsToConsumingPTransforms,
PCollectionConsumerRegistry pCollectionConsumerRegistry,
Set<String> processedPTransformIds,
Expand Down Expand Up @@ -268,6 +266,7 @@ private void createRunnerAndConsumersForPTransformRecursively(
cacheTokens,
bundleCache,
processBundleDescriptor,
components,
pCollectionIdsToConsumingPTransforms,
pCollectionConsumerRegistry,
processedPTransformIds,
Expand Down Expand Up @@ -358,18 +357,8 @@ public Supplier<List<CacheToken>> getCacheTokensSupplier() {
}

@Override
public Map<String, PCollection> getPCollections() {
return processBundleDescriptor.getPcollectionsMap();
}

@Override
public Map<String, Coder> getCoders() {
return processBundleDescriptor.getCodersMap();
}

@Override
public Map<String, WindowingStrategy> getWindowingStrategies() {
return processBundleDescriptor.getWindowingStrategiesMap();
public RunnerApi.Components getComponents() {
return components;
}

@Override
Expand Down Expand Up @@ -867,6 +856,13 @@ public void afterBundleCommit(Instant callbackExpiry, Callback callback) {
continue;
}

RunnerApi.Components components =
RunnerApi.Components.newBuilder()
.putAllCoders(bundleDescriptor.getCodersMap())
.putAllPcollections(bundleDescriptor.getPcollectionsMap())
.putAllWindowingStrategies(bundleDescriptor.getWindowingStrategiesMap())
.build();

createRunnerAndConsumersForPTransformRecursively(
beamFnStateClient,
beamFnDataClient,
Expand All @@ -876,6 +872,7 @@ public void afterBundleCommit(Instant callbackExpiry, Callback callback) {
bundleProcessor::getCacheTokens,
bundleProcessor::getBundleCache,
bundleDescriptor,
components,
pCollectionIdsToConsumingPTransforms,
pCollectionConsumerRegistry,
processedPTransformIds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,11 @@ public Coder<BoundedWindow> windowCoder() {
.build()
.toByteString()))
.build())
.pCollections(Collections.singletonMap("input", pCollection))
.coders(Collections.singletonMap("coder-id", coder))
.components(
RunnerApi.Components.newBuilder()
.putAllPcollections(Collections.singletonMap("input", pCollection))
.putAllCoders(Collections.singletonMap("coder-id", coder))
.build())
.build();
Collection<WindowedValue<?>> outputs = new ArrayList<>();
context.addPCollectionConsumer("output", outputs::add);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,17 @@ public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception {
PTransformRunnerFactoryTestContext context =
PTransformRunnerFactoryTestContext.builder(INPUT_TRANSFORM_ID, pTransform)
.processBundleInstructionId(DEFAULT_BUNDLE_ID)
.pCollections(
ImmutableMap.of(
localOutputId,
RunnerApi.PCollection.newBuilder().setCoderId(ELEMENT_CODER_SPEC_ID).build()))
.coders(COMPONENTS.getCodersMap())
.windowingStrategies(COMPONENTS.getWindowingStrategiesMap())
.components(
RunnerApi.Components.newBuilder()
.putAllPcollections(
ImmutableMap.of(
localOutputId,
RunnerApi.PCollection.newBuilder()
.setCoderId(ELEMENT_CODER_SPEC_ID)
.build()))
.putAllCoders(COMPONENTS.getCodersMap())
.putAllWindowingStrategies(COMPONENTS.getWindowingStrategiesMap())
.build())
.build();
context.<String>addPCollectionConsumer(localOutputId, outputValues::add);

Expand Down Expand Up @@ -187,12 +192,17 @@ public void testReuseForMultipleBundles() throws Exception {
INPUT_TRANSFORM_ID,
RemoteGrpcPortRead.readFromPort(PORT_SPEC, localOutputId).toPTransform())
.processBundleInstructionIdSupplier(bundleId::get)
.pCollections(
ImmutableMap.of(
localOutputId,
RunnerApi.PCollection.newBuilder().setCoderId(ELEMENT_CODER_SPEC_ID).build()))
.coders(COMPONENTS.getCodersMap())
.windowingStrategies(COMPONENTS.getWindowingStrategiesMap())
.components(
RunnerApi.Components.newBuilder()
.putAllPcollections(
ImmutableMap.of(
localOutputId,
RunnerApi.PCollection.newBuilder()
.setCoderId(ELEMENT_CODER_SPEC_ID)
.build()))
.putAllCoders(COMPONENTS.getCodersMap())
.putAllWindowingStrategies(COMPONENTS.getWindowingStrategiesMap())
.build())
.build();
context.<String>addPCollectionConsumer(localOutputId, outputValues::add);

Expand Down Expand Up @@ -659,12 +669,17 @@ private static BeamFnDataReadRunner<String> createReadRunner(
PTransformRunnerFactoryTestContext context =
PTransformRunnerFactoryTestContext.builder(pTransformId, pTransform)
.processBundleInstructionId(DEFAULT_BUNDLE_ID)
.pCollections(
ImmutableMap.of(
localOutputId,
RunnerApi.PCollection.newBuilder().setCoderId(ELEMENT_CODER_SPEC_ID).build()))
.coders(COMPONENTS.getCodersMap())
.windowingStrategies(COMPONENTS.getWindowingStrategiesMap())
.components(
RunnerApi.Components.newBuilder()
.putAllPcollections(
ImmutableMap.of(
localOutputId,
RunnerApi.PCollection.newBuilder()
.setCoderId(ELEMENT_CODER_SPEC_ID)
.build()))
.putAllCoders(COMPONENTS.getCodersMap())
.putAllWindowingStrategies(COMPONENTS.getWindowingStrategiesMap())
.build())
.build();
context.addPCollectionConsumer(localOutputId, consumer);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,15 @@ public void testReuseForMultipleBundles() throws Exception {
.beamFnDataClient(mockBeamFnDataClient)
.processBundleInstructionIdSupplier(bundleId::get)
.outboundAggregators(aggregators)
.pCollections(
ImmutableMap.of(
localInputId,
RunnerApi.PCollection.newBuilder().setCoderId(ELEM_CODER_ID).build()))
.coders(COMPONENTS.getCodersMap())
.windowingStrategies(COMPONENTS.getWindowingStrategiesMap())
.components(
RunnerApi.Components.newBuilder()
.putAllPcollections(
ImmutableMap.of(
localInputId,
RunnerApi.PCollection.newBuilder().setCoderId(ELEM_CODER_ID).build()))
.putAllCoders(COMPONENTS.getCodersMap())
.putAllWindowingStrategies(COMPONENTS.getWindowingStrategiesMap())
.build())
.build();

new BeamFnDataWriteRunner.Factory<String>().createRunnerForPTransform(context);
Expand Down
Loading

0 comments on commit 93476eb

Please sign in to comment.