diff --git a/flink/v1.17/build.gradle b/flink/v1.17/build.gradle index 2ced7a5a5cb7..0278e4dc3b73 100644 --- a/flink/v1.17/build.gradle +++ b/flink/v1.17/build.gradle @@ -66,6 +66,8 @@ project(":iceberg-flink:iceberg-flink-${flinkMajorVersion}") { exclude group: 'org.slf4j' } + implementation libs.datasketches + testImplementation libs.flink117.connector.test.utils testImplementation libs.flink117.core testImplementation libs.flink117.runtime diff --git a/flink/v1.17/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java b/flink/v1.17/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java index 3b2c74fd6ece..a9ad386a5a4a 100644 --- a/flink/v1.17/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java +++ b/flink/v1.17/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java @@ -19,6 +19,7 @@ package org.apache.iceberg.flink.sink.shuffle; import java.nio.charset.StandardCharsets; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.NavigableMap; @@ -28,6 +29,8 @@ import org.apache.iceberg.Schema; import org.apache.iceberg.SortKey; import org.apache.iceberg.SortOrder; +import org.apache.iceberg.SortOrderComparators; +import org.apache.iceberg.StructLike; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; @@ -67,6 +70,8 @@ public class MapRangePartitionerBenchmark { Types.NestedField.required(9, "name9", Types.StringType.get())); private static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA).asc("id").build(); + private static final Comparator SORT_ORDER_COMPARTOR = + SortOrderComparators.forSchema(SCHEMA, SORT_ORDER); private static final SortKey SORT_KEY = new SortKey(SCHEMA, SORT_ORDER); private MapRangePartitioner partitioner; @@ -83,10 +88,11 @@ public void setupBenchmark() { mapStatistics.put(sortKey, weight); }); - MapDataStatistics dataStatistics = new MapDataStatistics(mapStatistics); + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(2, mapStatistics, 0.0, SORT_ORDER_COMPARTOR); this.partitioner = new MapRangePartitioner( - SCHEMA, SortOrder.builderFor(SCHEMA).asc("id").build(), dataStatistics, 2); + SCHEMA, SortOrder.builderFor(SCHEMA).asc("id").build(), mapAssignment); List keys = Lists.newArrayList(weights.keySet().iterator()); long[] weightsCDF = new long[keys.size()]; diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java deleted file mode 100644 index 157f04b8b0ed..000000000000 --- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.iceberg.flink.sink.shuffle; - -import java.io.Serializable; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; -import org.apache.iceberg.relocated.com.google.common.base.Preconditions; - -/** - * AggregatedStatistics is used by {@link DataStatisticsCoordinator} to collect {@link - * DataStatistics} from {@link DataStatisticsOperator} subtasks for specific checkpoint. It stores - * the merged {@link DataStatistics} result from all reported subtasks. - */ -class AggregatedStatistics, S> implements Serializable { - - private final long checkpointId; - private final DataStatistics dataStatistics; - - AggregatedStatistics(long checkpoint, TypeSerializer> statisticsSerializer) { - this.checkpointId = checkpoint; - this.dataStatistics = statisticsSerializer.createInstance(); - } - - AggregatedStatistics(long checkpoint, DataStatistics dataStatistics) { - this.checkpointId = checkpoint; - this.dataStatistics = dataStatistics; - } - - long checkpointId() { - return checkpointId; - } - - DataStatistics dataStatistics() { - return dataStatistics; - } - - void mergeDataStatistic(String operatorName, long eventCheckpointId, D eventDataStatistics) { - Preconditions.checkArgument( - checkpointId == eventCheckpointId, - "Received unexpected event from operator %s checkpoint %s. 
Expected checkpoint %s", - operatorName, - eventCheckpointId, - checkpointId); - dataStatistics.merge(eventDataStatistics); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("checkpointId", checkpointId) - .add("dataStatistics", dataStatistics) - .toString(); - } -} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java index e8ff61dbeb27..338523b7b074 100644 --- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java @@ -18,116 +18,238 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import java.util.Map; +import java.util.NavigableMap; import java.util.Set; +import javax.annotation.Nullable; +import org.apache.datasketches.sampling.ReservoirItemsSketch; +import org.apache.datasketches.sampling.ReservoirItemsUnion; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.SortOrder; import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * AggregatedStatisticsTracker is used by {@link DataStatisticsCoordinator} to track the in progress - * {@link AggregatedStatistics} received from {@link DataStatisticsOperator} subtasks for specific - * checkpoint. + * AggregatedStatisticsTracker tracks the statistics aggregation received from {@link + * DataStatisticsOperator} subtasks for every checkpoint. 
*/ -class AggregatedStatisticsTracker, S> { +class AggregatedStatisticsTracker { private static final Logger LOG = LoggerFactory.getLogger(AggregatedStatisticsTracker.class); - private static final double ACCEPT_PARTIAL_AGGR_THRESHOLD = 90; + private final String operatorName; - private final TypeSerializer> statisticsSerializer; private final int parallelism; - private final Set inProgressSubtaskSet; - private volatile AggregatedStatistics inProgressStatistics; + private final TypeSerializer statisticsSerializer; + private final int downstreamParallelism; + private final StatisticsType statisticsType; + private final int switchToSketchThreshold; + private final NavigableMap aggregationsPerCheckpoint; + + private CompletedStatistics completedStatistics; AggregatedStatisticsTracker( String operatorName, - TypeSerializer> statisticsSerializer, - int parallelism) { + int parallelism, + Schema schema, + SortOrder sortOrder, + int downstreamParallelism, + StatisticsType statisticsType, + int switchToSketchThreshold, + @Nullable CompletedStatistics restoredStatistics) { this.operatorName = operatorName; - this.statisticsSerializer = statisticsSerializer; this.parallelism = parallelism; - this.inProgressSubtaskSet = Sets.newHashSet(); + this.statisticsSerializer = + new DataStatisticsSerializer(new SortKeySerializer(schema, sortOrder)); + this.downstreamParallelism = downstreamParallelism; + this.statisticsType = statisticsType; + this.switchToSketchThreshold = switchToSketchThreshold; + this.completedStatistics = restoredStatistics; + + this.aggregationsPerCheckpoint = Maps.newTreeMap(); } - AggregatedStatistics updateAndCheckCompletion( - int subtask, DataStatisticsEvent event) { + CompletedStatistics updateAndCheckCompletion(int subtask, StatisticsEvent event) { long checkpointId = event.checkpointId(); + LOG.debug( + "Handling statistics event from subtask {} of operator {} for checkpoint {}", + subtask, + operatorName, + checkpointId); - if (inProgressStatistics != null && inProgressStatistics.checkpointId() > checkpointId) { + if (completedStatistics != null && completedStatistics.checkpointId() > checkpointId) { LOG.info( - "Expect data statistics for operator {} checkpoint {}, but receive event from older checkpoint {}. Ignore it.", + "Ignore stale statistics event from operator {} subtask {} for older checkpoint {}. " + + "Was expecting data statistics from checkpoint higher than {}", operatorName, - inProgressStatistics.checkpointId(), - checkpointId); + subtask, + checkpointId, + completedStatistics.checkpointId()); return null; } - AggregatedStatistics completedStatistics = null; - if (inProgressStatistics != null && inProgressStatistics.checkpointId() < checkpointId) { - if ((double) inProgressSubtaskSet.size() / parallelism * 100 - >= ACCEPT_PARTIAL_AGGR_THRESHOLD) { - completedStatistics = inProgressStatistics; - LOG.info( - "Received data statistics from {} subtasks out of total {} for operator {} at checkpoint {}. 
" - + "Complete data statistics aggregation at checkpoint {} as it is more than the threshold of {} percentage", - inProgressSubtaskSet.size(), - parallelism, - operatorName, + Aggregation aggregation = + aggregationsPerCheckpoint.computeIfAbsent( checkpointId, - inProgressStatistics.checkpointId(), - ACCEPT_PARTIAL_AGGR_THRESHOLD); + ignored -> + new Aggregation( + parallelism, + downstreamParallelism, + switchToSketchThreshold, + statisticsType, + StatisticsUtil.collectType(statisticsType, completedStatistics))); + DataStatistics dataStatistics = + StatisticsUtil.deserializeDataStatistics(event.statisticsBytes(), statisticsSerializer); + if (!aggregation.merge(subtask, dataStatistics)) { + LOG.debug( + "Ignore duplicate data statistics from operator {} subtask {} for checkpoint {}.", + operatorName, + subtask, + checkpointId); + } + + if (aggregation.isComplete()) { + this.completedStatistics = aggregation.completedStatistics(checkpointId); + // clean up aggregations up to the completed checkpoint id + aggregationsPerCheckpoint.headMap(checkpointId, true).clear(); + return completedStatistics; + } + + return null; + } + + @VisibleForTesting + NavigableMap aggregationsPerCheckpoint() { + return aggregationsPerCheckpoint; + } + + static class Aggregation { + private static final Logger LOG = LoggerFactory.getLogger(Aggregation.class); + + private final Set subtaskSet; + private final int parallelism; + private final int downstreamParallelism; + private final int switchToSketchThreshold; + private final StatisticsType configuredType; + private StatisticsType currentType; + private Map mapStatistics; + private ReservoirItemsUnion sketchStatistics; + + Aggregation( + int parallelism, + int downstreamParallelism, + int switchToSketchThreshold, + StatisticsType configuredType, + StatisticsType currentType) { + this.subtaskSet = Sets.newHashSet(); + this.parallelism = parallelism; + this.downstreamParallelism = downstreamParallelism; + this.switchToSketchThreshold = switchToSketchThreshold; + this.configuredType = configuredType; + this.currentType = currentType; + + if (currentType == StatisticsType.Map) { + this.mapStatistics = Maps.newHashMap(); + this.sketchStatistics = null; } else { - LOG.info( - "Received data statistics from {} subtasks out of total {} for operator {} at checkpoint {}. 
" - + "Aborting the incomplete aggregation for checkpoint {}", - inProgressSubtaskSet.size(), - parallelism, - operatorName, - checkpointId, - inProgressStatistics.checkpointId()); + this.mapStatistics = null; + this.sketchStatistics = + ReservoirItemsUnion.newInstance( + SketchUtil.determineCoordinatorReservoirSize(downstreamParallelism)); } + } - inProgressStatistics = null; - inProgressSubtaskSet.clear(); + @VisibleForTesting + Set subtaskSet() { + return subtaskSet; } - if (inProgressStatistics == null) { - LOG.info("Starting a new data statistics for checkpoint {}", checkpointId); - inProgressStatistics = new AggregatedStatistics<>(checkpointId, statisticsSerializer); - inProgressSubtaskSet.clear(); + @VisibleForTesting + StatisticsType currentType() { + return currentType; } - if (!inProgressSubtaskSet.add(subtask)) { - LOG.debug( - "Ignore duplicated data statistics from operator {} subtask {} for checkpoint {}.", - operatorName, - subtask, - checkpointId); - } else { - inProgressStatistics.mergeDataStatistic( - operatorName, - event.checkpointId(), - DataStatisticsUtil.deserializeDataStatistics( - event.statisticsBytes(), statisticsSerializer)); + @VisibleForTesting + Map mapStatistics() { + return mapStatistics; } - if (inProgressSubtaskSet.size() == parallelism) { - completedStatistics = inProgressStatistics; - LOG.info( - "Received data statistics from all {} operators {} for checkpoint {}. Return last completed aggregator {}.", - parallelism, - operatorName, - inProgressStatistics.checkpointId(), - completedStatistics.dataStatistics()); - inProgressStatistics = new AggregatedStatistics<>(checkpointId + 1, statisticsSerializer); - inProgressSubtaskSet.clear(); + @VisibleForTesting + ReservoirItemsUnion sketchStatistics() { + return sketchStatistics; } - return completedStatistics; - } + private boolean isComplete() { + return subtaskSet.size() == parallelism; + } - @VisibleForTesting - AggregatedStatistics inProgressStatistics() { - return inProgressStatistics; + /** @return false if duplicate */ + private boolean merge(int subtask, DataStatistics taskStatistics) { + if (subtaskSet.contains(subtask)) { + return false; + } + + subtaskSet.add(subtask); + merge(taskStatistics); + return true; + } + + @SuppressWarnings("unchecked") + private void merge(DataStatistics taskStatistics) { + if (taskStatistics.type() == StatisticsType.Map) { + Map taskMapStats = (Map) taskStatistics.result(); + if (currentType == StatisticsType.Map) { + taskMapStats.forEach((key, count) -> mapStatistics.merge(key, count, Long::sum)); + if (configuredType == StatisticsType.Auto + && mapStatistics.size() > switchToSketchThreshold) { + convertCoordinatorToSketch(); + } + } else { + // convert task stats to sketch first + ReservoirItemsSketch taskSketch = + ReservoirItemsSketch.newInstance( + SketchUtil.determineOperatorReservoirSize(parallelism, downstreamParallelism)); + SketchUtil.convertMapToSketch(taskMapStats, taskSketch::update); + sketchStatistics.update(taskSketch); + } + } else { + ReservoirItemsSketch taskSketch = + (ReservoirItemsSketch) taskStatistics.result(); + if (currentType == StatisticsType.Map) { + // convert global stats to sketch first + convertCoordinatorToSketch(); + } + + sketchStatistics.update(taskSketch); + } + } + + private void convertCoordinatorToSketch() { + this.sketchStatistics = + ReservoirItemsUnion.newInstance( + SketchUtil.determineCoordinatorReservoirSize(downstreamParallelism)); + SketchUtil.convertMapToSketch(mapStatistics, sketchStatistics::update); + 
this.currentType = StatisticsType.Sketch; + this.mapStatistics = null; + } + + private CompletedStatistics completedStatistics(long checkpointId) { + if (currentType == StatisticsType.Map) { + LOG.info("Completed map statistics aggregation with {} keys", mapStatistics.size()); + return CompletedStatistics.fromKeyFrequency(checkpointId, mapStatistics); + } else { + ReservoirItemsSketch sketch = sketchStatistics.getResult(); + LOG.info( + "Completed sketch statistics aggregation: " + + "reservoir size = {}, number of items seen = {}, number of samples = {}", + sketch.getK(), + sketch.getN(), + sketch.getNumSamples()); + return CompletedStatistics.fromKeySamples(checkpointId, sketch.getSamples()); + } + } } } diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/CompletedStatistics.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/CompletedStatistics.java new file mode 100644 index 000000000000..c0e228965ddd --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/CompletedStatistics.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Arrays; +import java.util.Map; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; + +/** + * This is what {@link AggregatedStatisticsTracker} returns upon a completed statistics aggregation + * from all subtasks. It contains the raw statistics (Map or reservoir samples). 
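As an illustration (not part of this patch), the key samples carried here come from DataSketches reservoir sampling; a minimal, self-contained demonstration of that API, with String keys in place of Iceberg SortKey and an arbitrary reservoir size of 16, could look like:

import org.apache.datasketches.sampling.ReservoirItemsSketch;
import org.apache.datasketches.sampling.ReservoirItemsUnion;

public class ReservoirSampleDemo {
  public static void main(String[] args) {
    // Each operator subtask keeps a bounded reservoir of observed keys.
    // The reservoir size of 16 is an arbitrary choice for this example.
    ReservoirItemsSketch<String> subtaskA = ReservoirItemsSketch.newInstance(16);
    ReservoirItemsSketch<String> subtaskB = ReservoirItemsSketch.newInstance(16);
    for (int i = 0; i < 10_000; i++) {
      subtaskA.update("key-" + (i % 97));
      subtaskB.update("key-" + (i % 131));
    }

    // The coordinator side merges subtask sketches with a union, as the tracker does.
    ReservoirItemsUnion<String> union = ReservoirItemsUnion.newInstance(16);
    union.update(subtaskA);
    union.update(subtaskB);

    ReservoirItemsSketch<String> merged = union.getResult();
    // A fixed-size sample (at most 16 keys) regardless of how many rows were seen.
    System.out.println("items seen = " + merged.getN());
    System.out.println("samples    = " + merged.getNumSamples());
    for (String key : merged.getSamples()) {
      System.out.println(key);
    }
  }
}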
+ */ +class CompletedStatistics { + private final long checkpointId; + private final StatisticsType type; + private final Map keyFrequency; + private final SortKey[] keySamples; + + static CompletedStatistics fromKeyFrequency(long checkpointId, Map stats) { + return new CompletedStatistics(checkpointId, StatisticsType.Map, stats, null); + } + + static CompletedStatistics fromKeySamples(long checkpointId, SortKey[] keySamples) { + return new CompletedStatistics(checkpointId, StatisticsType.Sketch, null, keySamples); + } + + CompletedStatistics( + long checkpointId, + StatisticsType type, + Map keyFrequency, + SortKey[] keySamples) { + this.checkpointId = checkpointId; + this.type = type; + this.keyFrequency = keyFrequency; + this.keySamples = keySamples; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("checkpointId", checkpointId) + .add("type", type) + .add("keyFrequency", keyFrequency) + .add("keySamples", keySamples) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof CompletedStatistics)) { + return false; + } + + CompletedStatistics other = (CompletedStatistics) o; + return Objects.equal(checkpointId, other.checkpointId) + && Objects.equal(type, other.type) + && Objects.equal(keyFrequency, other.keyFrequency()) + && Arrays.equals(keySamples, other.keySamples()); + } + + @Override + public int hashCode() { + return Objects.hashCode(checkpointId, type, keyFrequency, keySamples); + } + + long checkpointId() { + return checkpointId; + } + + StatisticsType type() { + return type; + } + + Map keyFrequency() { + return keyFrequency; + } + + SortKey[] keySamples() { + return keySamples; + } +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/CompletedStatisticsSerializer.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/CompletedStatisticsSerializer.java new file mode 100644 index 000000000000..7f55188e7f8c --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/CompletedStatisticsSerializer.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.EnumSerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.MapSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.iceberg.SortKey; + +class CompletedStatisticsSerializer extends TypeSerializer { + private final TypeSerializer sortKeySerializer; + private final EnumSerializer statisticsTypeSerializer; + private final MapSerializer keyFrequencySerializer; + private final ListSerializer keySamplesSerializer; + + CompletedStatisticsSerializer(TypeSerializer sortKeySerializer) { + this.sortKeySerializer = sortKeySerializer; + this.statisticsTypeSerializer = new EnumSerializer<>(StatisticsType.class); + this.keyFrequencySerializer = new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE); + this.keySamplesSerializer = new ListSerializer<>(sortKeySerializer); + } + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer duplicate() { + return new CompletedStatisticsSerializer(sortKeySerializer); + } + + @Override + public CompletedStatistics createInstance() { + return CompletedStatistics.fromKeyFrequency(0L, Collections.emptyMap()); + } + + @Override + public CompletedStatistics copy(CompletedStatistics from) { + return new CompletedStatistics( + from.checkpointId(), from.type(), from.keyFrequency(), from.keySamples()); + } + + @Override + public CompletedStatistics copy(CompletedStatistics from, CompletedStatistics reuse) { + // no benefit of reuse + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(CompletedStatistics record, DataOutputView target) throws IOException { + target.writeLong(record.checkpointId()); + statisticsTypeSerializer.serialize(record.type(), target); + if (record.type() == StatisticsType.Map) { + keyFrequencySerializer.serialize(record.keyFrequency(), target); + } else { + keySamplesSerializer.serialize(Arrays.asList(record.keySamples()), target); + } + } + + @Override + public CompletedStatistics deserialize(DataInputView source) throws IOException { + long checkpointId = source.readLong(); + StatisticsType type = statisticsTypeSerializer.deserialize(source); + if (type == StatisticsType.Map) { + Map keyFrequency = keyFrequencySerializer.deserialize(source); + return CompletedStatistics.fromKeyFrequency(checkpointId, keyFrequency); + } else { + List sortKeys = keySamplesSerializer.deserialize(source); + SortKey[] keySamples = new SortKey[sortKeys.size()]; + keySamples = sortKeys.toArray(keySamples); + return CompletedStatistics.fromKeySamples(checkpointId, keySamples); + } + } + + @Override + public CompletedStatistics deserialize(CompletedStatistics reuse, DataInputView source) + throws IOException { + // not much benefit to reuse + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws 
IOException { + serialize(deserialize(source), target); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + CompletedStatisticsSerializer other = (CompletedStatisticsSerializer) obj; + return Objects.equals(sortKeySerializer, other.sortKeySerializer); + } + + @Override + public int hashCode() { + return sortKeySerializer.hashCode(); + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new CompletedStatisticsSerializerSnapshot(this); + } + + public static class CompletedStatisticsSerializerSnapshot + extends CompositeTypeSerializerSnapshot { + private static final int CURRENT_VERSION = 1; + + /** Constructor for read instantiation. */ + @SuppressWarnings({"unused", "checkstyle:RedundantModifier"}) + public CompletedStatisticsSerializerSnapshot() { + super(CompletedStatisticsSerializer.class); + } + + @SuppressWarnings("checkstyle:RedundantModifier") + public CompletedStatisticsSerializerSnapshot(CompletedStatisticsSerializer serializer) { + super(serializer); + } + + @Override + protected int getCurrentOuterSnapshotVersion() { + return CURRENT_VERSION; + } + + @Override + protected TypeSerializer[] getNestedSerializers( + CompletedStatisticsSerializer outerSerializer) { + return new TypeSerializer[] {outerSerializer.sortKeySerializer}; + } + + @Override + protected CompletedStatisticsSerializer createOuterSerializerWithNestedSerializers( + TypeSerializer[] nestedSerializers) { + SortKeySerializer sortKeySerializer = (SortKeySerializer) nestedSerializers[0]; + return new CompletedStatisticsSerializer(sortKeySerializer); + } + } +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java index 9d7cf179ab1c..76c59cd5f4b8 100644 --- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java @@ -18,6 +18,8 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import java.util.Map; +import org.apache.datasketches.sampling.ReservoirItemsSketch; import org.apache.flink.annotation.Internal; import org.apache.iceberg.SortKey; @@ -29,29 +31,18 @@ * (sketching) can be used. */ @Internal -interface DataStatistics, S> { +interface DataStatistics { + + StatisticsType type(); - /** - * Check if data statistics contains any statistics information. - * - * @return true if data statistics doesn't contain any statistics information - */ boolean isEmpty(); /** Add row sortKey to data statistics. */ void add(SortKey sortKey); /** - * Merge current statistics with other statistics. - * - * @param otherStatistics the statistics to be merged - */ - void merge(D otherStatistics); - - /** - * Get the underline statistics. - * - * @return the underline statistics + * Get the collected statistics. 
Could be a {@link Map} (low cardinality) or {@link + * ReservoirItemsSketch} (high cardinality) */ - S statistics(); + Object result(); } diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java index c8ac79c61bf6..3b21fbae315a 100644 --- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java @@ -18,6 +18,7 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import java.util.Comparator; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; @@ -35,6 +36,10 @@ import org.apache.flink.util.Preconditions; import org.apache.flink.util.ThrowableCatchingRunnable; import org.apache.flink.util.function.ThrowingRunnable; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.SortOrderComparators; +import org.apache.iceberg.StructLike; import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; import org.apache.iceberg.relocated.com.google.common.collect.Iterables; import org.apache.iceberg.relocated.com.google.common.collect.Maps; @@ -44,51 +49,86 @@ import org.slf4j.LoggerFactory; /** - * DataStatisticsCoordinator receives {@link DataStatisticsEvent} from {@link - * DataStatisticsOperator} every subtask and then merge them together. Once aggregation for all - * subtasks data statistics completes, DataStatisticsCoordinator will send the aggregated data - * statistics back to {@link DataStatisticsOperator}. In the end a custom partitioner will - * distribute traffic based on the aggregated data statistics to improve data clustering. + * DataStatisticsCoordinator receives {@link StatisticsEvent} from {@link DataStatisticsOperator} + * every subtask and then merge them together. Once aggregation for all subtasks data statistics + * completes, DataStatisticsCoordinator will send the aggregated data statistics back to {@link + * DataStatisticsOperator}. In the end a custom partitioner will distribute traffic based on the + * aggregated data statistics to improve data clustering. 
*/ @Internal -class DataStatisticsCoordinator, S> implements OperatorCoordinator { +class DataStatisticsCoordinator implements OperatorCoordinator { private static final Logger LOG = LoggerFactory.getLogger(DataStatisticsCoordinator.class); private final String operatorName; + private final OperatorCoordinator.Context context; + private final Schema schema; + private final SortOrder sortOrder; + private final Comparator comparator; + private final int downstreamParallelism; + private final StatisticsType statisticsType; + private final double closeFileCostWeightPercentage; + private final ExecutorService coordinatorExecutor; - private final OperatorCoordinator.Context operatorCoordinatorContext; private final SubtaskGateways subtaskGateways; private final CoordinatorExecutorThreadFactory coordinatorThreadFactory; - private final TypeSerializer> statisticsSerializer; - private final transient AggregatedStatisticsTracker aggregatedStatisticsTracker; - private volatile AggregatedStatistics completedStatistics; - private volatile boolean started; + private final TypeSerializer completedStatisticsSerializer; + private final TypeSerializer globalStatisticsSerializer; + + private transient boolean started; + private transient AggregatedStatisticsTracker aggregatedStatisticsTracker; + private transient CompletedStatistics completedStatistics; + private transient GlobalStatistics globalStatistics; DataStatisticsCoordinator( String operatorName, OperatorCoordinator.Context context, - TypeSerializer> statisticsSerializer) { + Schema schema, + SortOrder sortOrder, + int downstreamParallelism, + StatisticsType statisticsType, + double closeFileCostWeightPercentage) { this.operatorName = operatorName; + this.context = context; + this.schema = schema; + this.sortOrder = sortOrder; + this.comparator = SortOrderComparators.forSchema(schema, sortOrder); + this.downstreamParallelism = downstreamParallelism; + this.statisticsType = statisticsType; + this.closeFileCostWeightPercentage = closeFileCostWeightPercentage; + this.coordinatorThreadFactory = new CoordinatorExecutorThreadFactory( "DataStatisticsCoordinator-" + operatorName, context.getUserCodeClassloader()); this.coordinatorExecutor = Executors.newSingleThreadExecutor(coordinatorThreadFactory); - this.operatorCoordinatorContext = context; - this.subtaskGateways = new SubtaskGateways(operatorName, parallelism()); - this.statisticsSerializer = statisticsSerializer; - this.aggregatedStatisticsTracker = - new AggregatedStatisticsTracker<>(operatorName, statisticsSerializer, parallelism()); + this.subtaskGateways = new SubtaskGateways(operatorName, context.currentParallelism()); + SortKeySerializer sortKeySerializer = new SortKeySerializer(schema, sortOrder); + this.completedStatisticsSerializer = new CompletedStatisticsSerializer(sortKeySerializer); + this.globalStatisticsSerializer = new GlobalStatisticsSerializer(sortKeySerializer); } @Override public void start() throws Exception { LOG.info("Starting data statistics coordinator: {}.", operatorName); - started = true; + this.started = true; + + // statistics are restored already in resetToCheckpoint() before start() called + this.aggregatedStatisticsTracker = + new AggregatedStatisticsTracker( + operatorName, + context.currentParallelism(), + schema, + sortOrder, + downstreamParallelism, + statisticsType, + SketchUtil.COORDINATOR_SKETCH_SWITCH_THRESHOLD, + completedStatistics); } @Override public void close() throws Exception { coordinatorExecutor.shutdown(); + this.aggregatedStatisticsTracker = null; + 
this.started = false; LOG.info("Closed data statistics coordinator: {}.", operatorName); } @@ -148,7 +188,7 @@ private void runInCoordinatorThread(ThrowingRunnable action, String a operatorName, actionString, t); - operatorCoordinatorContext.failJob(t); + context.failJob(t); } }); } @@ -157,42 +197,102 @@ private void ensureStarted() { Preconditions.checkState(started, "The coordinator of %s has not started yet.", operatorName); } - private int parallelism() { - return operatorCoordinatorContext.currentParallelism(); - } - - private void handleDataStatisticRequest(int subtask, DataStatisticsEvent event) { - AggregatedStatistics aggregatedStatistics = + private void handleDataStatisticRequest(int subtask, StatisticsEvent event) { + CompletedStatistics maybeCompletedStatistics = aggregatedStatisticsTracker.updateAndCheckCompletion(subtask, event); - if (aggregatedStatistics != null) { - completedStatistics = aggregatedStatistics; - sendDataStatisticsToSubtasks( - completedStatistics.checkpointId(), completedStatistics.dataStatistics()); + if (maybeCompletedStatistics != null) { + // completedStatistics contains the complete samples, which is needed to compute + // the range bounds in globalStatistics if downstreamParallelism changed. + this.completedStatistics = maybeCompletedStatistics; + // globalStatistics only contains assignment calculated based on Map or Sketch statistics + this.globalStatistics = + globalStatistics( + maybeCompletedStatistics, + downstreamParallelism, + comparator, + closeFileCostWeightPercentage); + sendGlobalStatisticsToSubtasks(globalStatistics); + } + } + + private static GlobalStatistics globalStatistics( + CompletedStatistics completedStatistics, + int downstreamParallelism, + Comparator comparator, + double closeFileCostWeightPercentage) { + if (completedStatistics.type() == StatisticsType.Sketch) { + // range bound is a much smaller array compared to the complete samples. + // It helps reduce the amount of data transfer from coordinator to operator subtasks. + return GlobalStatistics.fromRangeBounds( + completedStatistics.checkpointId(), + SketchUtil.rangeBounds( + downstreamParallelism, comparator, completedStatistics.keySamples())); + } else { + return GlobalStatistics.fromMapAssignment( + completedStatistics.checkpointId(), + MapAssignment.fromKeyFrequency( + downstreamParallelism, + completedStatistics.keyFrequency(), + closeFileCostWeightPercentage, + comparator)); } } @SuppressWarnings("FutureReturnValueIgnored") - private void sendDataStatisticsToSubtasks( - long checkpointId, DataStatistics globalDataStatistics) { - callInCoordinatorThread( + private void sendGlobalStatisticsToSubtasks(GlobalStatistics statistics) { + runInCoordinatorThread( () -> { - DataStatisticsEvent dataStatisticsEvent = - DataStatisticsEvent.create(checkpointId, globalDataStatistics, statisticsSerializer); - int parallelism = parallelism(); - for (int i = 0; i < parallelism; ++i) { - subtaskGateways.getSubtaskGateway(i).sendEvent(dataStatisticsEvent); + LOG.info( + "Broadcast latest global statistics from checkpoint {} to all subtasks", + statistics.checkpointId()); + // applyImmediately is set to false so that operator subtasks can + // apply the change at checkpoint boundary + StatisticsEvent statisticsEvent = + StatisticsEvent.createGlobalStatisticsEvent( + statistics, globalStatisticsSerializer, false); + for (int i = 0; i < context.currentParallelism(); ++i) { + // Ignore future return value for potential error (e.g. subtask down). 
+ // Upon restart, subtasks send request to coordinator to refresh statistics + // if there is any difference + subtaskGateways.getSubtaskGateway(i).sendEvent(statisticsEvent); } - - return null; }, String.format( "Failed to send operator %s coordinator global data statistics for checkpoint %d", - operatorName, checkpointId)); + operatorName, statistics.checkpointId())); + } + + @SuppressWarnings("FutureReturnValueIgnored") + private void handleRequestGlobalStatisticsEvent(int subtask, RequestGlobalStatisticsEvent event) { + if (globalStatistics != null) { + runInCoordinatorThread( + () -> { + if (event.signature() != null && event.signature() != globalStatistics.hashCode()) { + LOG.debug( + "Skip responding to statistics request from subtask {}, as hashCode matches or not included in the request", + subtask); + } else { + LOG.info( + "Send latest global statistics from checkpoint {} to subtask {}", + globalStatistics.checkpointId(), + subtask); + StatisticsEvent statisticsEvent = + StatisticsEvent.createGlobalStatisticsEvent( + globalStatistics, globalStatisticsSerializer, true); + subtaskGateways.getSubtaskGateway(subtask).sendEvent(statisticsEvent); + } + }, + String.format( + "Failed to send operator %s coordinator global data statistics to requesting subtask %d for checkpoint %d", + operatorName, subtask, globalStatistics.checkpointId())); + } else { + LOG.info( + "Ignore global statistics request from subtask {} as statistics not available", subtask); + } } @Override - @SuppressWarnings("unchecked") public void handleEventFromOperator(int subtask, int attemptNumber, OperatorEvent event) { runInCoordinatorThread( () -> { @@ -202,8 +302,14 @@ public void handleEventFromOperator(int subtask, int attemptNumber, OperatorEven attemptNumber, operatorName, event); - Preconditions.checkArgument(event instanceof DataStatisticsEvent); - handleDataStatisticRequest(subtask, ((DataStatisticsEvent) event)); + if (event instanceof StatisticsEvent) { + handleDataStatisticRequest(subtask, ((StatisticsEvent) event)); + } else if (event instanceof RequestGlobalStatisticsEvent) { + handleRequestGlobalStatisticsEvent(subtask, (RequestGlobalStatisticsEvent) event); + } else { + throw new IllegalArgumentException( + "Invalid operator event type: " + event.getClass().getCanonicalName()); + } }, String.format( "handling operator event %s from subtask %d (#%d)", @@ -219,8 +325,8 @@ public void checkpointCoordinator(long checkpointId, CompletableFuture r operatorName, checkpointId); resultFuture.complete( - DataStatisticsUtil.serializeAggregatedStatistics( - completedStatistics, statisticsSerializer)); + StatisticsUtil.serializeCompletedStatistics( + completedStatistics, completedStatisticsSerializer)); }, String.format("taking checkpoint %d", checkpointId)); } @@ -229,11 +335,9 @@ public void checkpointCoordinator(long checkpointId, CompletableFuture r public void notifyCheckpointComplete(long checkpointId) {} @Override - public void resetToCheckpoint(long checkpointId, @Nullable byte[] checkpointData) - throws Exception { + public void resetToCheckpoint(long checkpointId, byte[] checkpointData) { Preconditions.checkState( !started, "The coordinator %s can only be reset if it was not yet started", operatorName); - if (checkpointData == null) { LOG.info( "Data statistic coordinator {} has nothing to restore from checkpoint {}", @@ -244,8 +348,13 @@ public void resetToCheckpoint(long checkpointId, @Nullable byte[] checkpointData LOG.info( "Restoring data statistic coordinator {} from checkpoint {}", 
operatorName, checkpointId); - completedStatistics = - DataStatisticsUtil.deserializeAggregatedStatistics(checkpointData, statisticsSerializer); + this.completedStatistics = + StatisticsUtil.deserializeCompletedStatistics( + checkpointData, completedStatisticsSerializer); + // recompute global statistics in case downstream parallelism changed + this.globalStatistics = + globalStatistics( + completedStatistics, downstreamParallelism, comparator, closeFileCostWeightPercentage); } @Override @@ -269,7 +378,7 @@ public void executionAttemptFailed(int subtask, int attemptNumber, @Nullable Thr runInCoordinatorThread( () -> { LOG.info( - "Unregistering gateway after failure for subtask {} (#{}) of data statistic {}", + "Unregistering gateway after failure for subtask {} (#{}) of data statistics {}", subtask, attemptNumber, operatorName); @@ -295,14 +404,20 @@ public void executionAttemptReady(int subtask, int attemptNumber, SubtaskGateway } @VisibleForTesting - AggregatedStatistics completedStatistics() { + CompletedStatistics completedStatistics() { return completedStatistics; } + @VisibleForTesting + GlobalStatistics globalStatistics() { + return globalStatistics; + } + private static class SubtaskGateways { private final String operatorName; private final Map[] gateways; + @SuppressWarnings("unchecked") private SubtaskGateways(String operatorName, int parallelism) { this.operatorName = operatorName; gateways = new Map[parallelism]; diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java index 47dbfc3cfbe1..9d7d989c298e 100644 --- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java @@ -19,33 +19,52 @@ package org.apache.iceberg.flink.sink.shuffle; import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.operators.coordination.OperatorCoordinator; import org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; /** * DataStatisticsCoordinatorProvider provides the method to create new {@link * DataStatisticsCoordinator} */ @Internal -public class DataStatisticsCoordinatorProvider, S> - extends RecreateOnResetOperatorCoordinator.Provider { +public class DataStatisticsCoordinatorProvider extends RecreateOnResetOperatorCoordinator.Provider { private final String operatorName; - private final TypeSerializer> statisticsSerializer; + private final Schema schema; + private final SortOrder sortOrder; + private final int downstreamParallelism; + private final StatisticsType type; + private final double closeFileCostWeightPercentage; public DataStatisticsCoordinatorProvider( String operatorName, OperatorID operatorID, - TypeSerializer> statisticsSerializer) { + Schema schema, + SortOrder sortOrder, + int downstreamParallelism, + StatisticsType type, + double closeFileCostWeightPercentage) { super(operatorID); this.operatorName = operatorName; - this.statisticsSerializer = statisticsSerializer; + this.schema = schema; + this.sortOrder = sortOrder; + this.downstreamParallelism = downstreamParallelism; + this.type = type; + 
this.closeFileCostWeightPercentage = closeFileCostWeightPercentage; } @Override public OperatorCoordinator getCoordinator(OperatorCoordinator.Context context) { - return new DataStatisticsCoordinator<>(operatorName, context, statisticsSerializer); + return new DataStatisticsCoordinator( + operatorName, + context, + schema, + sortOrder, + downstreamParallelism, + type, + closeFileCostWeightPercentage); } } diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java index 5157a37cf2cd..59c38b239725 100644 --- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java @@ -18,6 +18,7 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import java.util.Map; import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; @@ -47,9 +48,8 @@ * distribution to downstream subtasks. */ @Internal -class DataStatisticsOperator, S> - extends AbstractStreamOperator> - implements OneInputStreamOperator>, OperatorEventHandler { +public class DataStatisticsOperator extends AbstractStreamOperator + implements OneInputStreamOperator, OperatorEventHandler { private static final long serialVersionUID = 1L; @@ -57,141 +57,209 @@ class DataStatisticsOperator, S> private final RowDataWrapper rowDataWrapper; private final SortKey sortKey; private final OperatorEventGateway operatorEventGateway; - private final TypeSerializer> statisticsSerializer; - private transient volatile DataStatistics localStatistics; - private transient volatile DataStatistics globalStatistics; - private transient ListState> globalStatisticsState; + private final int downstreamParallelism; + private final StatisticsType statisticsType; + private final TypeSerializer taskStatisticsSerializer; + private final TypeSerializer globalStatisticsSerializer; + + private transient int parallelism; + private transient int subtaskIndex; + private transient ListState globalStatisticsState; + // current statistics type may be different from the config due to possible + // migration from Map statistics to Sketch statistics when high cardinality detected + private transient volatile StatisticsType taskStatisticsType; + private transient volatile DataStatistics localStatistics; + private transient volatile GlobalStatistics globalStatistics; DataStatisticsOperator( String operatorName, Schema schema, SortOrder sortOrder, OperatorEventGateway operatorEventGateway, - TypeSerializer> statisticsSerializer) { + int downstreamParallelism, + StatisticsType statisticsType) { this.operatorName = operatorName; this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct()); this.sortKey = new SortKey(schema, sortOrder); this.operatorEventGateway = operatorEventGateway; - this.statisticsSerializer = statisticsSerializer; + this.downstreamParallelism = downstreamParallelism; + this.statisticsType = statisticsType; + + SortKeySerializer sortKeySerializer = new SortKeySerializer(schema, sortOrder); + this.taskStatisticsSerializer = new DataStatisticsSerializer(sortKeySerializer); + this.globalStatisticsSerializer = new GlobalStatisticsSerializer(sortKeySerializer); } @Override public void initializeState(StateInitializationContext context) throws Exception { - 
localStatistics = statisticsSerializer.createInstance(); - globalStatisticsState = + this.parallelism = getRuntimeContext().getNumberOfParallelSubtasks(); + this.subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); + + // Use union state so that new subtasks can also restore global statistics during scale-up. + this.globalStatisticsState = context .getOperatorStateStore() .getUnionListState( - new ListStateDescriptor<>("globalStatisticsState", statisticsSerializer)); + new ListStateDescriptor<>("globalStatisticsState", globalStatisticsSerializer)); if (context.isRestored()) { - int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); if (globalStatisticsState.get() == null || !globalStatisticsState.get().iterator().hasNext()) { - LOG.warn( + LOG.info( "Operator {} subtask {} doesn't have global statistics state to restore", operatorName, subtaskIndex); - globalStatistics = statisticsSerializer.createInstance(); + // If Flink deprecates union state in the future, RequestGlobalStatisticsEvent can be + // leveraged to request global statistics from coordinator if new subtasks (scale-up case) + // has nothing to restore from. } else { + GlobalStatistics restoredStatistics = globalStatisticsState.get().iterator().next(); LOG.info( - "Restoring operator {} global statistics state for subtask {}", - operatorName, - subtaskIndex); - globalStatistics = globalStatisticsState.get().iterator().next(); + "Operator {} subtask {} restored global statistics state", operatorName, subtaskIndex); + this.globalStatistics = restoredStatistics; } - } else { - globalStatistics = statisticsSerializer.createInstance(); + + // Always request for new statistics from coordinator upon task initialization. + // There are a few scenarios this is needed + // 1. downstream writer parallelism changed due to rescale. + // 2. coordinator failed to send the aggregated statistics to subtask + // (e.g. due to subtask failure at the time). + // Records may flow before coordinator can respond. Range partitioner should be + // able to continue to operate with potentially suboptimal behavior (in sketch case). + LOG.info( + "Operator {} subtask {} requests new global statistics from coordinator ", + operatorName, + subtaskIndex); + // coordinator can use the hashCode (if available) in the request event to determine + // if operator already has the latest global statistics and respond can be skipped. + // This makes the handling cheap in most situations. + RequestGlobalStatisticsEvent event = + globalStatistics != null + ? 
new RequestGlobalStatisticsEvent(globalStatistics.hashCode()) + : new RequestGlobalStatisticsEvent(); + operatorEventGateway.sendEventToCoordinator(event); } + + this.taskStatisticsType = StatisticsUtil.collectType(statisticsType, globalStatistics); + this.localStatistics = + StatisticsUtil.createTaskStatistics(taskStatisticsType, parallelism, downstreamParallelism); } @Override public void open() throws Exception { - if (!globalStatistics.isEmpty()) { - output.collect( - new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + if (globalStatistics != null) { + output.collect(new StreamRecord<>(StatisticsOrRecord.fromStatistics(globalStatistics))); } } @Override - @SuppressWarnings("unchecked") public void handleOperatorEvent(OperatorEvent event) { - int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); Preconditions.checkArgument( - event instanceof DataStatisticsEvent, + event instanceof StatisticsEvent, String.format( "Operator %s subtask %s received unexpected operator event %s", operatorName, subtaskIndex, event.getClass())); - DataStatisticsEvent statisticsEvent = (DataStatisticsEvent) event; + StatisticsEvent statisticsEvent = (StatisticsEvent) event; LOG.info( - "Operator {} received global data event from coordinator checkpoint {}", + "Operator {} subtask {} received global data event from coordinator checkpoint {}", operatorName, + subtaskIndex, statisticsEvent.checkpointId()); - globalStatistics = - DataStatisticsUtil.deserializeDataStatistics( - statisticsEvent.statisticsBytes(), statisticsSerializer); - output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + this.globalStatistics = + StatisticsUtil.deserializeGlobalStatistics( + statisticsEvent.statisticsBytes(), globalStatisticsSerializer); + checkStatisticsTypeMigration(); + // if applyImmediately not set, wait until the checkpoint time to switch + if (statisticsEvent.applyImmediately()) { + output.collect(new StreamRecord<>(StatisticsOrRecord.fromStatistics(globalStatistics))); + } } @Override public void processElement(StreamRecord streamRecord) { + // collect data statistics RowData record = streamRecord.getValue(); StructLike struct = rowDataWrapper.wrap(record); sortKey.wrap(struct); localStatistics.add(sortKey); - output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromRecord(record))); + + checkStatisticsTypeMigration(); + output.collect(new StreamRecord<>(StatisticsOrRecord.fromRecord(record))); } @Override public void snapshotState(StateSnapshotContext context) throws Exception { long checkpointId = context.getCheckpointId(); - int subTaskId = getRuntimeContext().getIndexOfThisSubtask(); LOG.info( - "Snapshotting data statistics operator {} for checkpoint {} in subtask {}", + "Operator {} subtask {} snapshotting data statistics for checkpoint {}", operatorName, - checkpointId, - subTaskId); + subtaskIndex, + checkpointId); - // Pass global statistics to partitioners so that all the operators refresh statistics + // Pass global statistics to partitioner so that all the operators refresh statistics // at same checkpoint barrier - if (!globalStatistics.isEmpty()) { - output.collect( - new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + if (globalStatistics != null) { + output.collect(new StreamRecord<>(StatisticsOrRecord.fromStatistics(globalStatistics))); } // Only subtask 0 saves the state so that globalStatisticsState(UnionListState) stores // an exact copy of globalStatistics - if 
(!globalStatistics.isEmpty() && getRuntimeContext().getIndexOfThisSubtask() == 0) { + if (globalStatistics != null && getRuntimeContext().getIndexOfThisSubtask() == 0) { globalStatisticsState.clear(); LOG.info( - "Saving operator {} global statistics {} to state in subtask {}", - operatorName, - globalStatistics, - subTaskId); + "Operator {} subtask {} saving global statistics to state", operatorName, subtaskIndex); globalStatisticsState.add(globalStatistics); + LOG.debug( + "Operator {} subtask {} saved global statistics to state: {}", + operatorName, + subtaskIndex, + globalStatistics); } // For now, local statistics are sent to coordinator at checkpoint - operatorEventGateway.sendEventToCoordinator( - DataStatisticsEvent.create(checkpointId, localStatistics, statisticsSerializer)); - LOG.debug( - "Subtask {} of operator {} sent local statistics to coordinator at checkpoint{}: {}", - subTaskId, + LOG.info( + "Operator {} Subtask {} sending local statistics to coordinator for checkpoint {}", operatorName, - checkpointId, - localStatistics); + subtaskIndex, + checkpointId); + operatorEventGateway.sendEventToCoordinator( + StatisticsEvent.createTaskStatisticsEvent( + checkpointId, localStatistics, taskStatisticsSerializer)); // Recreate the local statistics - localStatistics = statisticsSerializer.createInstance(); + localStatistics = + StatisticsUtil.createTaskStatistics(taskStatisticsType, parallelism, downstreamParallelism); + } + + @SuppressWarnings("unchecked") + private void checkStatisticsTypeMigration() { + // only check if the statisticsType config is Auto and localStatistics is currently Map type + if (statisticsType == StatisticsType.Auto && localStatistics.type() == StatisticsType.Map) { + Map mapStatistics = (Map) localStatistics.result(); + // convert if local statistics has cardinality over the threshold or + // if received global statistics is already sketch type + if (mapStatistics.size() > SketchUtil.OPERATOR_SKETCH_SWITCH_THRESHOLD + || (globalStatistics != null && globalStatistics.type() == StatisticsType.Sketch)) { + LOG.info( + "Operator {} subtask {} switched local statistics from Map to Sketch.", + operatorName, + subtaskIndex); + this.taskStatisticsType = StatisticsType.Sketch; + this.localStatistics = + StatisticsUtil.createTaskStatistics( + taskStatisticsType, parallelism, downstreamParallelism); + SketchUtil.convertMapToSketch(mapStatistics, localStatistics::add); + } + } } @VisibleForTesting - DataStatistics localDataStatistics() { + DataStatistics localStatistics() { return localStatistics; } @VisibleForTesting - DataStatistics globalDataStatistics() { + GlobalStatistics globalStatistics() { return globalStatistics; } } diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsSerializer.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsSerializer.java new file mode 100644 index 000000000000..c25481b3c1f2 --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsSerializer.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.sampling.ReservoirItemsSketch; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.EnumSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.MapSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; + +@Internal +class DataStatisticsSerializer extends TypeSerializer { + private final TypeSerializer sortKeySerializer; + private final EnumSerializer statisticsTypeSerializer; + private final MapSerializer mapSerializer; + private final SortKeySketchSerializer sketchSerializer; + + DataStatisticsSerializer(TypeSerializer sortKeySerializer) { + this.sortKeySerializer = sortKeySerializer; + this.statisticsTypeSerializer = new EnumSerializer<>(StatisticsType.class); + this.mapSerializer = new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE); + this.sketchSerializer = new SortKeySketchSerializer(sortKeySerializer); + } + + @Override + public boolean isImmutableType() { + return false; + } + + @SuppressWarnings("ReferenceEquality") + @Override + public TypeSerializer duplicate() { + TypeSerializer duplicateSortKeySerializer = sortKeySerializer.duplicate(); + return (duplicateSortKeySerializer == sortKeySerializer) + ? this + : new DataStatisticsSerializer(duplicateSortKeySerializer); + } + + @Override + public DataStatistics createInstance() { + return new MapDataStatistics(); + } + + @SuppressWarnings("unchecked") + @Override + public DataStatistics copy(DataStatistics obj) { + StatisticsType statisticsType = obj.type(); + if (statisticsType == StatisticsType.Map) { + MapDataStatistics from = (MapDataStatistics) obj; + Map fromStats = (Map) from.result(); + Map toStats = Maps.newHashMap(fromStats); + return new MapDataStatistics(toStats); + } else if (statisticsType == StatisticsType.Sketch) { + // because ReservoirItemsSketch doesn't expose enough public methods for cloning, + // this implementation adopted the less efficient serialization and deserialization. 
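// The copy below round-trips the sketch: toByteArray(sketchSerializer) followed by
// ReservoirItemsSketch.heapify on the wrapped bytes yields an independent sketch instance.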
+ SketchDataStatistics from = (SketchDataStatistics) obj; + ReservoirItemsSketch fromStats = (ReservoirItemsSketch) from.result(); + byte[] bytes = fromStats.toByteArray(sketchSerializer); + Memory memory = Memory.wrap(bytes); + ReservoirItemsSketch toStats = + ReservoirItemsSketch.heapify(memory, sketchSerializer); + return new SketchDataStatistics(toStats); + } else { + throw new IllegalArgumentException("Unsupported data statistics type: " + statisticsType); + } + } + + @Override + public DataStatistics copy(DataStatistics from, DataStatistics reuse) { + // not much benefit to reuse + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @SuppressWarnings("unchecked") + @Override + public void serialize(DataStatistics obj, DataOutputView target) throws IOException { + StatisticsType statisticsType = obj.type(); + statisticsTypeSerializer.serialize(obj.type(), target); + if (statisticsType == StatisticsType.Map) { + Map mapStatistics = (Map) obj.result(); + mapSerializer.serialize(mapStatistics, target); + } else if (statisticsType == StatisticsType.Sketch) { + ReservoirItemsSketch sketch = (ReservoirItemsSketch) obj.result(); + byte[] sketchBytes = sketch.toByteArray(sketchSerializer); + target.writeInt(sketchBytes.length); + target.write(sketchBytes); + } else { + throw new IllegalArgumentException("Unsupported data statistics type: " + statisticsType); + } + } + + @Override + public DataStatistics deserialize(DataInputView source) throws IOException { + StatisticsType statisticsType = statisticsTypeSerializer.deserialize(source); + if (statisticsType == StatisticsType.Map) { + Map mapStatistics = mapSerializer.deserialize(source); + return new MapDataStatistics(mapStatistics); + } else if (statisticsType == StatisticsType.Sketch) { + int numBytes = source.readInt(); + byte[] sketchBytes = new byte[numBytes]; + source.read(sketchBytes); + Memory sketchMemory = Memory.wrap(sketchBytes); + ReservoirItemsSketch sketch = + ReservoirItemsSketch.heapify(sketchMemory, sketchSerializer); + return new SketchDataStatistics(sketch); + } else { + throw new IllegalArgumentException("Unsupported data statistics type: " + statisticsType); + } + } + + @Override + public DataStatistics deserialize(DataStatistics reuse, DataInputView source) throws IOException { + // not much benefit to reuse + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof DataStatisticsSerializer)) { + return false; + } + + DataStatisticsSerializer other = (DataStatisticsSerializer) obj; + return Objects.equals(sortKeySerializer, other.sortKeySerializer); + } + + @Override + public int hashCode() { + return sortKeySerializer.hashCode(); + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new DataStatisticsSerializerSnapshot(this); + } + + public static class DataStatisticsSerializerSnapshot + extends CompositeTypeSerializerSnapshot { + private static final int CURRENT_VERSION = 1; + + /** Constructor for read instantiation. 
*/ + @SuppressWarnings({"unused", "checkstyle:RedundantModifier"}) + public DataStatisticsSerializerSnapshot() { + super(DataStatisticsSerializer.class); + } + + @SuppressWarnings("checkstyle:RedundantModifier") + public DataStatisticsSerializerSnapshot(DataStatisticsSerializer serializer) { + super(serializer); + } + + @Override + protected int getCurrentOuterSnapshotVersion() { + return CURRENT_VERSION; + } + + @Override + protected TypeSerializer[] getNestedSerializers(DataStatisticsSerializer outerSerializer) { + return new TypeSerializer[] {outerSerializer.sortKeySerializer}; + } + + @Override + protected DataStatisticsSerializer createOuterSerializerWithNestedSerializers( + TypeSerializer[] nestedSerializers) { + SortKeySerializer sortKeySerializer = (SortKeySerializer) nestedSerializers[0]; + return new DataStatisticsSerializer(sortKeySerializer); + } + } +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java deleted file mode 100644 index 8716cb872d0e..000000000000 --- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package org.apache.iceberg.flink.sink.shuffle; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.core.memory.DataInputDeserializer; -import org.apache.flink.core.memory.DataOutputSerializer; - -/** - * DataStatisticsUtil is the utility to serialize and deserialize {@link DataStatistics} and {@link - * AggregatedStatistics} - */ -class DataStatisticsUtil { - - private DataStatisticsUtil() {} - - static , S> byte[] serializeDataStatistics( - DataStatistics dataStatistics, - TypeSerializer> statisticsSerializer) { - DataOutputSerializer out = new DataOutputSerializer(64); - try { - statisticsSerializer.serialize(dataStatistics, out); - return out.getCopyOfBuffer(); - } catch (IOException e) { - throw new IllegalStateException("Fail to serialize data statistics", e); - } - } - - @SuppressWarnings("unchecked") - static , S> D deserializeDataStatistics( - byte[] bytes, TypeSerializer> statisticsSerializer) { - DataInputDeserializer input = new DataInputDeserializer(bytes, 0, bytes.length); - try { - return (D) statisticsSerializer.deserialize(input); - } catch (IOException e) { - throw new IllegalStateException("Fail to deserialize data statistics", e); - } - } - - static , S> byte[] serializeAggregatedStatistics( - AggregatedStatistics aggregatedStatistics, - TypeSerializer> statisticsSerializer) - throws IOException { - ByteArrayOutputStream bytes = new ByteArrayOutputStream(); - ObjectOutputStream out = new ObjectOutputStream(bytes); - - DataOutputSerializer outSerializer = new DataOutputSerializer(64); - out.writeLong(aggregatedStatistics.checkpointId()); - statisticsSerializer.serialize(aggregatedStatistics.dataStatistics(), outSerializer); - byte[] statisticsBytes = outSerializer.getCopyOfBuffer(); - out.writeInt(statisticsBytes.length); - out.write(statisticsBytes); - out.flush(); - - return bytes.toByteArray(); - } - - static , S> - AggregatedStatistics deserializeAggregatedStatistics( - byte[] bytes, TypeSerializer> statisticsSerializer) - throws IOException { - ByteArrayInputStream bytesIn = new ByteArrayInputStream(bytes); - ObjectInputStream in = new ObjectInputStream(bytesIn); - - long completedCheckpointId = in.readLong(); - int statisticsBytesLength = in.readInt(); - byte[] statisticsBytes = new byte[statisticsBytesLength]; - in.readFully(statisticsBytes); - DataInputDeserializer input = - new DataInputDeserializer(statisticsBytes, 0, statisticsBytesLength); - DataStatistics dataStatistics = statisticsSerializer.deserialize(input); - - return new AggregatedStatistics<>(completedCheckpointId, dataStatistics); - } -} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatistics.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatistics.java new file mode 100644 index 000000000000..50ec23e9f7a2 --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatistics.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Arrays; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; + +/** + * This is used by {@link RangePartitioner} for guiding range partitioning. This is what is sent to + * the operator subtasks. For sketch statistics, it only contains much smaller range bounds than the + * complete raw samples. + */ +class GlobalStatistics { + private final long checkpointId; + private final StatisticsType type; + private final MapAssignment mapAssignment; + private final SortKey[] rangeBounds; + + private transient Integer hashCode; + + GlobalStatistics( + long checkpointId, StatisticsType type, MapAssignment mapAssignment, SortKey[] rangeBounds) { + Preconditions.checkArgument( + (mapAssignment != null && rangeBounds == null) + || (mapAssignment == null && rangeBounds != null), + "Invalid key assignment or range bounds: both are non-null or null"); + this.checkpointId = checkpointId; + this.type = type; + this.mapAssignment = mapAssignment; + this.rangeBounds = rangeBounds; + } + + static GlobalStatistics fromMapAssignment(long checkpointId, MapAssignment mapAssignment) { + return new GlobalStatistics(checkpointId, StatisticsType.Map, mapAssignment, null); + } + + static GlobalStatistics fromRangeBounds(long checkpointId, SortKey[] rangeBounds) { + return new GlobalStatistics(checkpointId, StatisticsType.Sketch, null, rangeBounds); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("checkpointId", checkpointId) + .add("type", type) + .add("mapAssignment", mapAssignment) + .add("rangeBounds", rangeBounds) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof GlobalStatistics)) { + return false; + } + + GlobalStatistics other = (GlobalStatistics) o; + return Objects.equal(checkpointId, other.checkpointId) + && Objects.equal(type, other.type) + && Objects.equal(mapAssignment, other.mapAssignment()) + && Arrays.equals(rangeBounds, other.rangeBounds()); + } + + @Override + public int hashCode() { + // implemented caching because coordinator can call the hashCode many times. + // when subtasks request statistics refresh upon initialization for reconciliation purpose, + // hashCode is used to check if there is any difference btw coordinator and operator state. 
+ if (hashCode == null) { + this.hashCode = Objects.hashCode(checkpointId, type, mapAssignment, rangeBounds); + } + + return hashCode; + } + + long checkpointId() { + return checkpointId; + } + + StatisticsType type() { + return type; + } + + MapAssignment mapAssignment() { + return mapAssignment; + } + + SortKey[] rangeBounds() { + return rangeBounds; + } +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatisticsSerializer.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatisticsSerializer.java new file mode 100644 index 000000000000..dfb947a84a0c --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatisticsSerializer.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.EnumSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; + +class GlobalStatisticsSerializer extends TypeSerializer { + private final TypeSerializer sortKeySerializer; + private final EnumSerializer statisticsTypeSerializer; + private final ListSerializer rangeBoundsSerializer; + private final ListSerializer intsSerializer; + private final ListSerializer longsSerializer; + + GlobalStatisticsSerializer(TypeSerializer sortKeySerializer) { + this.sortKeySerializer = sortKeySerializer; + this.statisticsTypeSerializer = new EnumSerializer<>(StatisticsType.class); + this.rangeBoundsSerializer = new ListSerializer<>(sortKeySerializer); + this.intsSerializer = new ListSerializer<>(IntSerializer.INSTANCE); + this.longsSerializer = new ListSerializer<>(LongSerializer.INSTANCE); + } + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer duplicate() { + return new GlobalStatisticsSerializer(sortKeySerializer); + } + + @Override + public GlobalStatistics createInstance() { + return GlobalStatistics.fromRangeBounds(0L, new SortKey[0]); + } + + @Override + public GlobalStatistics 
copy(GlobalStatistics from) { + return new GlobalStatistics( + from.checkpointId(), from.type(), from.mapAssignment(), from.rangeBounds()); + } + + @Override + public GlobalStatistics copy(GlobalStatistics from, GlobalStatistics reuse) { + // no benefit of reuse + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(GlobalStatistics record, DataOutputView target) throws IOException { + target.writeLong(record.checkpointId()); + statisticsTypeSerializer.serialize(record.type(), target); + if (record.type() == StatisticsType.Map) { + MapAssignment mapAssignment = record.mapAssignment(); + target.writeInt(mapAssignment.numPartitions()); + target.writeInt(mapAssignment.keyAssignments().size()); + for (Map.Entry entry : mapAssignment.keyAssignments().entrySet()) { + sortKeySerializer.serialize(entry.getKey(), target); + KeyAssignment keyAssignment = entry.getValue(); + intsSerializer.serialize(keyAssignment.assignedSubtasks(), target); + longsSerializer.serialize(keyAssignment.subtaskWeightsWithCloseFileCost(), target); + target.writeLong(keyAssignment.closeFileCostWeight()); + } + } else { + rangeBoundsSerializer.serialize(Arrays.asList(record.rangeBounds()), target); + } + } + + @Override + public GlobalStatistics deserialize(DataInputView source) throws IOException { + long checkpointId = source.readLong(); + StatisticsType type = statisticsTypeSerializer.deserialize(source); + if (type == StatisticsType.Map) { + int numPartitions = source.readInt(); + int mapSize = source.readInt(); + Map keyAssignments = Maps.newHashMapWithExpectedSize(mapSize); + for (int i = 0; i < mapSize; ++i) { + SortKey sortKey = sortKeySerializer.deserialize(source); + List assignedSubtasks = intsSerializer.deserialize(source); + List subtaskWeightsWithCloseFileCost = longsSerializer.deserialize(source); + long closeFileCostWeight = source.readLong(); + keyAssignments.put( + sortKey, + new KeyAssignment( + assignedSubtasks, subtaskWeightsWithCloseFileCost, closeFileCostWeight)); + } + + return GlobalStatistics.fromMapAssignment( + checkpointId, new MapAssignment(numPartitions, keyAssignments)); + } else { + List sortKeys = rangeBoundsSerializer.deserialize(source); + SortKey[] rangeBounds = new SortKey[sortKeys.size()]; + return GlobalStatistics.fromRangeBounds(checkpointId, sortKeys.toArray(rangeBounds)); + } + } + + @Override + public GlobalStatistics deserialize(GlobalStatistics reuse, DataInputView source) + throws IOException { + // not much benefit to reuse + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + GlobalStatisticsSerializer other = (GlobalStatisticsSerializer) obj; + return Objects.equals(sortKeySerializer, other.sortKeySerializer); + } + + @Override + public int hashCode() { + return sortKeySerializer.hashCode(); + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new GlobalStatisticsSerializerSnapshot(this); + } + + public static class GlobalStatisticsSerializerSnapshot + extends CompositeTypeSerializerSnapshot { + private static final int CURRENT_VERSION = 1; + + /** Constructor for read instantiation. 
*/ + @SuppressWarnings({"unused", "checkstyle:RedundantModifier"}) + public GlobalStatisticsSerializerSnapshot() { + super(GlobalStatisticsSerializer.class); + } + + @SuppressWarnings("checkstyle:RedundantModifier") + public GlobalStatisticsSerializerSnapshot(GlobalStatisticsSerializer serializer) { + super(serializer); + } + + @Override + protected int getCurrentOuterSnapshotVersion() { + return CURRENT_VERSION; + } + + @Override + protected TypeSerializer[] getNestedSerializers(GlobalStatisticsSerializer outerSerializer) { + return new TypeSerializer[] {outerSerializer.sortKeySerializer}; + } + + @Override + protected GlobalStatisticsSerializer createOuterSerializerWithNestedSerializers( + TypeSerializer[] nestedSerializers) { + SortKeySerializer sortKeySerializer = (SortKeySerializer) nestedSerializers[0]; + return new GlobalStatisticsSerializer(sortKeySerializer); + } + } +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/KeyAssignment.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/KeyAssignment.java new file mode 100644 index 000000000000..a164d83ac3b0 --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/KeyAssignment.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; + +/** Subtask assignment for a key for Map statistics based */ +class KeyAssignment { + private final List assignedSubtasks; + private final List subtaskWeightsWithCloseFileCost; + private final long closeFileCostWeight; + private final long[] subtaskWeightsExcludingCloseCost; + private final long keyWeight; + private final long[] cumulativeWeights; + + /** + * @param assignedSubtasks assigned subtasks for this key. It could be a single subtask. It could + * also be multiple subtasks if the key has heavy weight that should be handled by multiple + * subtasks. + * @param subtaskWeightsWithCloseFileCost assigned weight for each subtask. E.g., if the keyWeight + * is 27 and the key is assigned to 3 subtasks, subtaskWeights could contain values as [10, + * 10, 7] for target weight of 10 per subtask. 
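 * <p>Illustrative aside (close file cost value assumed, not part of the original example): with a
 * close file cost of 1, those weights [10, 10, 7] reduce to [9, 9, 6] once the close file cost is
 * excluded, giving a key weight of 24 and cumulative weights [9, 18, 24] that select() uses for
 * weighted random routing.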
+ */ + KeyAssignment( + List assignedSubtasks, + List subtaskWeightsWithCloseFileCost, + long closeFileCostWeight) { + Preconditions.checkArgument( + assignedSubtasks != null && !assignedSubtasks.isEmpty(), + "Invalid assigned subtasks: null or empty"); + Preconditions.checkArgument( + subtaskWeightsWithCloseFileCost != null && !subtaskWeightsWithCloseFileCost.isEmpty(), + "Invalid assigned subtasks weights: null or empty"); + Preconditions.checkArgument( + assignedSubtasks.size() == subtaskWeightsWithCloseFileCost.size(), + "Invalid assignment: size mismatch (tasks length = %s, weights length = %s)", + assignedSubtasks.size(), + subtaskWeightsWithCloseFileCost.size()); + subtaskWeightsWithCloseFileCost.forEach( + weight -> + Preconditions.checkArgument( + weight > closeFileCostWeight, + "Invalid weight: should be larger than close file cost: weight = %s, close file cost = %s", + weight, + closeFileCostWeight)); + + this.assignedSubtasks = assignedSubtasks; + this.subtaskWeightsWithCloseFileCost = subtaskWeightsWithCloseFileCost; + this.closeFileCostWeight = closeFileCostWeight; + // Exclude the close file cost for key routing + this.subtaskWeightsExcludingCloseCost = + subtaskWeightsWithCloseFileCost.stream() + .mapToLong(weightWithCloseFileCost -> weightWithCloseFileCost - closeFileCostWeight) + .toArray(); + this.keyWeight = Arrays.stream(subtaskWeightsExcludingCloseCost).sum(); + this.cumulativeWeights = new long[subtaskWeightsExcludingCloseCost.length]; + long cumulativeWeight = 0; + for (int i = 0; i < subtaskWeightsExcludingCloseCost.length; ++i) { + cumulativeWeight += subtaskWeightsExcludingCloseCost[i]; + cumulativeWeights[i] = cumulativeWeight; + } + } + + List assignedSubtasks() { + return assignedSubtasks; + } + + List subtaskWeightsWithCloseFileCost() { + return subtaskWeightsWithCloseFileCost; + } + + long closeFileCostWeight() { + return closeFileCostWeight; + } + + long[] subtaskWeightsExcludingCloseCost() { + return subtaskWeightsExcludingCloseCost; + } + + /** @return subtask id */ + int select() { + if (assignedSubtasks.size() == 1) { + // only choice. no need to run random number generator. + return assignedSubtasks.get(0); + } else { + long randomNumber = ThreadLocalRandom.current().nextLong(keyWeight); + int index = Arrays.binarySearch(cumulativeWeights, randomNumber); + // choose the subtask where randomNumber < cumulativeWeights[pos]. + // this works regardless whether index is negative or not. + int position = Math.abs(index + 1); + Preconditions.checkState( + position < assignedSubtasks.size(), + "Invalid selected position: out of range. 
key weight = %s, random number = %s, cumulative weights array = %s", + keyWeight, + randomNumber, + cumulativeWeights); + return assignedSubtasks.get(position); + } + } + + @Override + public int hashCode() { + return Objects.hash(assignedSubtasks, subtaskWeightsWithCloseFileCost, closeFileCostWeight); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + KeyAssignment that = (KeyAssignment) o; + return Objects.equals(assignedSubtasks, that.assignedSubtasks) + && Objects.equals(subtaskWeightsWithCloseFileCost, that.subtaskWeightsWithCloseFileCost) + && closeFileCostWeight == that.closeFileCostWeight; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("assignedSubtasks", assignedSubtasks) + .add("subtaskWeightsWithCloseFileCost", subtaskWeightsWithCloseFileCost) + .add("closeFileCostWeight", closeFileCostWeight) + .toString(); + } +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapAssignment.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapAssignment.java new file mode 100644 index 000000000000..0abb030c2279 --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapAssignment.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NavigableMap; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.util.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Key assignment to subtasks for Map statistics. 
*/ +class MapAssignment { + private static final Logger LOG = LoggerFactory.getLogger(MapAssignment.class); + + private final int numPartitions; + private final Map keyAssignments; + + MapAssignment(int numPartitions, Map keyAssignments) { + Preconditions.checkArgument(keyAssignments != null, "Invalid key assignments: null"); + this.numPartitions = numPartitions; + this.keyAssignments = keyAssignments; + } + + static MapAssignment fromKeyFrequency( + int numPartitions, + Map mapStatistics, + double closeFileCostWeightPercentage, + Comparator comparator) { + return new MapAssignment( + numPartitions, + assignment(numPartitions, mapStatistics, closeFileCostWeightPercentage, comparator)); + } + + @Override + public int hashCode() { + return Objects.hashCode(numPartitions, keyAssignments); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + MapAssignment that = (MapAssignment) o; + return numPartitions == that.numPartitions && keyAssignments.equals(that.keyAssignments); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("numPartitions", numPartitions) + .add("keyAssignments", keyAssignments) + .toString(); + } + + int numPartitions() { + return numPartitions; + } + + Map keyAssignments() { + return keyAssignments; + } + + /** + * @return assignment summary for every subtask. Key is subtaskId. Value pair is (weight assigned + * to the subtask, number of keys assigned to the subtask) + */ + Map> assignmentInfo() { + Map> assignmentInfo = Maps.newTreeMap(); + keyAssignments.forEach( + (key, keyAssignment) -> { + for (int i = 0; i < keyAssignment.assignedSubtasks().size(); ++i) { + int subtaskId = keyAssignment.assignedSubtasks().get(i); + long subtaskWeight = keyAssignment.subtaskWeightsExcludingCloseCost()[i]; + Pair oldValue = assignmentInfo.getOrDefault(subtaskId, Pair.of(0L, 0)); + assignmentInfo.put( + subtaskId, Pair.of(oldValue.first() + subtaskWeight, oldValue.second() + 1)); + } + }); + + return assignmentInfo; + } + + static Map assignment( + int numPartitions, + Map mapStatistics, + double closeFileCostWeightPercentage, + Comparator comparator) { + mapStatistics.forEach( + (key, value) -> + Preconditions.checkArgument( + value > 0, "Invalid statistics: weight is 0 for key %s", key)); + + long totalWeight = mapStatistics.values().stream().mapToLong(l -> l).sum(); + double targetWeightPerSubtask = ((double) totalWeight) / numPartitions; + long closeFileCostWeight = + (long) Math.ceil(targetWeightPerSubtask * closeFileCostWeightPercentage / 100); + + NavigableMap sortedStatsWithCloseFileCost = Maps.newTreeMap(comparator); + mapStatistics.forEach( + (k, v) -> { + int estimatedSplits = (int) Math.ceil(v / targetWeightPerSubtask); + long estimatedCloseFileCost = closeFileCostWeight * estimatedSplits; + sortedStatsWithCloseFileCost.put(k, v + estimatedCloseFileCost); + }); + + long totalWeightWithCloseFileCost = + sortedStatsWithCloseFileCost.values().stream().mapToLong(l -> l).sum(); + long targetWeightPerSubtaskWithCloseFileCost = + (long) Math.ceil(((double) totalWeightWithCloseFileCost) / numPartitions); + return buildAssignment( + numPartitions, + sortedStatsWithCloseFileCost, + targetWeightPerSubtaskWithCloseFileCost, + closeFileCostWeight); + } + + private static Map buildAssignment( + int numPartitions, + NavigableMap sortedStatistics, + long targetWeightPerSubtask, + long closeFileCostWeight) { + Map assignmentMap = + 
Maps.newHashMapWithExpectedSize(sortedStatistics.size()); + Iterator mapKeyIterator = sortedStatistics.keySet().iterator(); + int subtaskId = 0; + SortKey currentKey = null; + long keyRemainingWeight = 0L; + long subtaskRemainingWeight = targetWeightPerSubtask; + List assignedSubtasks = Lists.newArrayList(); + List subtaskWeights = Lists.newArrayList(); + while (mapKeyIterator.hasNext() || currentKey != null) { + // This should never happen because target weight is calculated using ceil function. + if (subtaskId >= numPartitions) { + LOG.error( + "Internal algorithm error: exhausted subtasks with unassigned keys left. number of partitions: {}, " + + "target weight per subtask: {}, close file cost in weight: {}, data statistics: {}", + numPartitions, + targetWeightPerSubtask, + closeFileCostWeight, + sortedStatistics); + throw new IllegalStateException( + "Internal algorithm error: exhausted subtasks with unassigned keys left"); + } + + if (currentKey == null) { + currentKey = mapKeyIterator.next(); + keyRemainingWeight = sortedStatistics.get(currentKey); + } + + assignedSubtasks.add(subtaskId); + if (keyRemainingWeight < subtaskRemainingWeight) { + // assign the remaining weight of the key to the current subtask + subtaskWeights.add(keyRemainingWeight); + subtaskRemainingWeight -= keyRemainingWeight; + keyRemainingWeight = 0L; + } else { + // filled up the current subtask + long assignedWeight = subtaskRemainingWeight; + keyRemainingWeight -= subtaskRemainingWeight; + + // If assigned weight is less than close file cost, pad it up with close file cost. + // This might cause the subtask assigned weight over the target weight. + // But it should be no more than one close file cost. Small skew is acceptable. + if (assignedWeight <= closeFileCostWeight) { + long paddingWeight = Math.min(keyRemainingWeight, closeFileCostWeight); + keyRemainingWeight -= paddingWeight; + assignedWeight += paddingWeight; + } + + subtaskWeights.add(assignedWeight); + // move on to the next subtask + subtaskId += 1; + subtaskRemainingWeight = targetWeightPerSubtask; + } + + Preconditions.checkState( + assignedSubtasks.size() == subtaskWeights.size(), + "List size mismatch: assigned subtasks = %s, subtask weights = %s", + assignedSubtasks, + subtaskWeights); + + // If the remaining key weight is smaller than the close file cost, simply skip the residual + // as it doesn't make sense to assign a weight smaller than close file cost to a new subtask. + // this might lead to some inaccuracy in weight calculation. E.g., assuming the key weight is + // 2 and close file cost is 2. key weight with close cost is 4. Let's assume the previous + // task has a weight of 3 available. So weight of 3 for this key is assigned to the task and + // the residual weight of 1 is dropped. Then the routing weight for this key is 1 (minus the + // close file cost), which is inaccurate as the true key weight should be 2. + // Again, this greedy algorithm is not intended to be perfect. Some small inaccuracy is + // expected and acceptable. Traffic distribution should still be balanced. 
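// Illustrative trace (inputs assumed, not from the patch): with 2 partitions, zero close file cost,
// and sorted weights {k1=100, k2=50, k3=50}, the target weight per subtask is 100, so k1 fills
// subtask 0 exactly while k2 and k3 each land on subtask 1 with weight 50. A key heavier than the
// target, e.g. weight 150 against a target of 100, is split across consecutive subtasks as [100, 50].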
+ if (keyRemainingWeight > 0 && keyRemainingWeight <= closeFileCostWeight) { + keyRemainingWeight = 0; + } + + if (keyRemainingWeight == 0) { + // finishing up the assignment for the current key + KeyAssignment keyAssignment = + new KeyAssignment(assignedSubtasks, subtaskWeights, closeFileCostWeight); + assignmentMap.put(currentKey, keyAssignment); + assignedSubtasks = Lists.newArrayList(); + subtaskWeights = Lists.newArrayList(); + currentKey = null; + } + } + + return assignmentMap; + } +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java index 0b63e2721178..05b943f6046f 100644 --- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java @@ -19,52 +19,70 @@ package org.apache.iceberg.flink.sink.shuffle; import java.util.Map; -import org.apache.flink.annotation.Internal; import org.apache.iceberg.SortKey; import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; import org.apache.iceberg.relocated.com.google.common.collect.Maps; /** MapDataStatistics uses map to count key frequency */ -@Internal -class MapDataStatistics implements DataStatistics> { - private final Map statistics; +class MapDataStatistics implements DataStatistics { + private final Map keyFrequency; MapDataStatistics() { - this.statistics = Maps.newHashMap(); + this.keyFrequency = Maps.newHashMap(); } - MapDataStatistics(Map statistics) { - this.statistics = statistics; + MapDataStatistics(Map keyFrequency) { + this.keyFrequency = keyFrequency; + } + + @Override + public StatisticsType type() { + return StatisticsType.Map; } @Override public boolean isEmpty() { - return statistics.isEmpty(); + return keyFrequency.isEmpty(); } @Override public void add(SortKey sortKey) { - if (statistics.containsKey(sortKey)) { - statistics.merge(sortKey, 1L, Long::sum); + if (keyFrequency.containsKey(sortKey)) { + keyFrequency.merge(sortKey, 1L, Long::sum); } else { // clone the sort key before adding to map because input sortKey object can be reused SortKey copiedKey = sortKey.copy(); - statistics.put(copiedKey, 1L); + keyFrequency.put(copiedKey, 1L); } } @Override - public void merge(MapDataStatistics otherStatistics) { - otherStatistics.statistics().forEach((key, count) -> statistics.merge(key, count, Long::sum)); + public Object result() { + return keyFrequency; } @Override - public Map statistics() { - return statistics; + public String toString() { + return MoreObjects.toStringHelper(this).add("map", keyFrequency).toString(); } @Override - public String toString() { - return MoreObjects.toStringHelper(this).add("statistics", statistics).toString(); + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof MapDataStatistics)) { + return false; + } + + MapDataStatistics other = (MapDataStatistics) o; + return Objects.equal(keyFrequency, other.keyFrequency); + } + + @Override + public int hashCode() { + return Objects.hashCode(keyFrequency); } } diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsSerializer.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsSerializer.java deleted file mode 100644 index b6cccd0566fc..000000000000 --- 
a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsSerializer.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.iceberg.flink.sink.shuffle; - -import java.io.IOException; -import java.util.Map; -import java.util.Objects; -import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; -import org.apache.flink.api.common.typeutils.base.LongSerializer; -import org.apache.flink.api.common.typeutils.base.MapSerializer; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.util.Preconditions; -import org.apache.iceberg.SortKey; -import org.apache.iceberg.relocated.com.google.common.collect.Maps; - -@Internal -class MapDataStatisticsSerializer - extends TypeSerializer>> { - private final MapSerializer mapSerializer; - - static MapDataStatisticsSerializer fromSortKeySerializer( - TypeSerializer sortKeySerializer) { - return new MapDataStatisticsSerializer( - new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE)); - } - - MapDataStatisticsSerializer(MapSerializer mapSerializer) { - this.mapSerializer = mapSerializer; - } - - @Override - public boolean isImmutableType() { - return false; - } - - @SuppressWarnings("ReferenceEquality") - @Override - public TypeSerializer>> duplicate() { - MapSerializer duplicateMapSerializer = - (MapSerializer) mapSerializer.duplicate(); - return (duplicateMapSerializer == mapSerializer) - ? 
this - : new MapDataStatisticsSerializer(duplicateMapSerializer); - } - - @Override - public MapDataStatistics createInstance() { - return new MapDataStatistics(); - } - - @Override - public MapDataStatistics copy(DataStatistics> obj) { - Preconditions.checkArgument( - obj instanceof MapDataStatistics, "Invalid data statistics type: " + obj.getClass()); - MapDataStatistics from = (MapDataStatistics) obj; - TypeSerializer keySerializer = mapSerializer.getKeySerializer(); - Map newMap = Maps.newHashMapWithExpectedSize(from.statistics().size()); - for (Map.Entry entry : from.statistics().entrySet()) { - SortKey newKey = keySerializer.copy(entry.getKey()); - // no need to copy value since it is just a Long - newMap.put(newKey, entry.getValue()); - } - - return new MapDataStatistics(newMap); - } - - @Override - public DataStatistics> copy( - DataStatistics> from, - DataStatistics> reuse) { - // not much benefit to reuse - return copy(from); - } - - @Override - public int getLength() { - return -1; - } - - @Override - public void serialize( - DataStatistics> obj, DataOutputView target) - throws IOException { - Preconditions.checkArgument( - obj instanceof MapDataStatistics, "Invalid data statistics type: " + obj.getClass()); - MapDataStatistics mapStatistics = (MapDataStatistics) obj; - mapSerializer.serialize(mapStatistics.statistics(), target); - } - - @Override - public DataStatistics> deserialize(DataInputView source) - throws IOException { - return new MapDataStatistics(mapSerializer.deserialize(source)); - } - - @Override - public DataStatistics> deserialize( - DataStatistics> reuse, DataInputView source) - throws IOException { - // not much benefit to reuse - return deserialize(source); - } - - @Override - public void copy(DataInputView source, DataOutputView target) throws IOException { - mapSerializer.copy(source, target); - } - - @Override - public boolean equals(Object obj) { - if (!(obj instanceof MapDataStatisticsSerializer)) { - return false; - } - - MapDataStatisticsSerializer other = (MapDataStatisticsSerializer) obj; - return Objects.equals(mapSerializer, other.mapSerializer); - } - - @Override - public int hashCode() { - return mapSerializer.hashCode(); - } - - @Override - public TypeSerializerSnapshot>> - snapshotConfiguration() { - return new MapDataStatisticsSerializerSnapshot(this); - } - - public static class MapDataStatisticsSerializerSnapshot - extends CompositeTypeSerializerSnapshot< - DataStatistics>, MapDataStatisticsSerializer> { - private static final int CURRENT_VERSION = 1; - - // constructors need to public. Otherwise, Flink state restore would complain - // "The class has no (implicit) public nullary constructor". 
- @SuppressWarnings("checkstyle:RedundantModifier") - public MapDataStatisticsSerializerSnapshot() { - super(MapDataStatisticsSerializer.class); - } - - @SuppressWarnings("checkstyle:RedundantModifier") - public MapDataStatisticsSerializerSnapshot(MapDataStatisticsSerializer serializer) { - super(serializer); - } - - @Override - protected int getCurrentOuterSnapshotVersion() { - return CURRENT_VERSION; - } - - @Override - protected TypeSerializer[] getNestedSerializers( - MapDataStatisticsSerializer outerSerializer) { - return new TypeSerializer[] {outerSerializer.mapSerializer}; - } - - @Override - protected MapDataStatisticsSerializer createOuterSerializerWithNestedSerializers( - TypeSerializer[] nestedSerializers) { - @SuppressWarnings("unchecked") - MapSerializer mapSerializer = - (MapSerializer) nestedSerializers[0]; - return new MapDataStatisticsSerializer(mapSerializer); - } - } -} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java index dde86b5b6047..f36a078c94e0 100644 --- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java @@ -18,29 +18,14 @@ */ package org.apache.iceberg.flink.sink.shuffle; -import java.util.Arrays; -import java.util.Comparator; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.NavigableMap; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import org.apache.flink.api.common.functions.Partitioner; import org.apache.flink.table.data.RowData; import org.apache.iceberg.Schema; import org.apache.iceberg.SortKey; import org.apache.iceberg.SortOrder; -import org.apache.iceberg.SortOrderComparators; -import org.apache.iceberg.StructLike; import org.apache.iceberg.flink.FlinkSchemaUtil; import org.apache.iceberg.flink.RowDataWrapper; -import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; -import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; -import org.apache.iceberg.relocated.com.google.common.base.Preconditions; -import org.apache.iceberg.relocated.com.google.common.collect.Lists; -import org.apache.iceberg.relocated.com.google.common.collect.Maps; -import org.apache.iceberg.util.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,52 +46,28 @@ class MapRangePartitioner implements Partitioner { private final RowDataWrapper rowDataWrapper; private final SortKey sortKey; - private final Comparator comparator; - private final Map mapStatistics; - private final double closeFileCostInWeightPercentage; + private final MapAssignment mapAssignment; // Counter that tracks how many times a new key encountered // where there is no traffic statistics learned about it. 
private long newSortKeyCounter; private long lastNewSortKeyLogTimeMilli; - // lazily computed due to the need of numPartitions - private Map assignment; - private NavigableMap sortedStatsWithCloseFileCost; - - MapRangePartitioner( - Schema schema, - SortOrder sortOrder, - MapDataStatistics dataStatistics, - double closeFileCostInWeightPercentage) { - dataStatistics - .statistics() - .entrySet() - .forEach( - entry -> - Preconditions.checkArgument( - entry.getValue() > 0, - "Invalid statistics: weight is 0 for key %s", - entry.getKey())); - + MapRangePartitioner(Schema schema, SortOrder sortOrder, MapAssignment mapAssignment) { this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct()); this.sortKey = new SortKey(schema, sortOrder); - this.comparator = SortOrderComparators.forSchema(schema, sortOrder); - this.mapStatistics = dataStatistics.statistics(); - this.closeFileCostInWeightPercentage = closeFileCostInWeightPercentage; + this.mapAssignment = mapAssignment; this.newSortKeyCounter = 0; this.lastNewSortKeyLogTimeMilli = System.currentTimeMillis(); } @Override public int partition(RowData row, int numPartitions) { - // assignment table can only be built lazily when first referenced here, - // because number of partitions (downstream subtasks) is needed. - // the numPartitions is not available in the constructor. - Map assignmentMap = assignment(numPartitions); // reuse the sortKey and rowDataWrapper sortKey.wrap(rowDataWrapper.wrap(row)); - KeyAssignment keyAssignment = assignmentMap.get(sortKey); + KeyAssignment keyAssignment = mapAssignment.keyAssignments().get(sortKey); + + int partition; if (keyAssignment == null) { LOG.trace( "Encountered new sort key: {}. Fall back to round robin as statistics not learned yet.", @@ -117,271 +78,18 @@ public int partition(RowData row, int numPartitions) { newSortKeyCounter += 1; long now = System.currentTimeMillis(); if (now - lastNewSortKeyLogTimeMilli > TimeUnit.MINUTES.toMillis(1)) { - LOG.info("Encounter new sort keys in total {} times", newSortKeyCounter); + LOG.info( + "Encounter new sort keys {} times. 
Fall back to round robin as statistics not learned yet", + newSortKeyCounter); lastNewSortKeyLogTimeMilli = now; + newSortKeyCounter = 0; } - return (int) (newSortKeyCounter % numPartitions); + partition = (int) (newSortKeyCounter % numPartitions); + } else { + partition = keyAssignment.select(); } - return keyAssignment.select(); - } - - @VisibleForTesting - Map assignment(int numPartitions) { - if (assignment == null) { - long totalWeight = mapStatistics.values().stream().mapToLong(l -> l).sum(); - double targetWeightPerSubtask = ((double) totalWeight) / numPartitions; - long closeFileCostInWeight = - (long) Math.ceil(targetWeightPerSubtask * closeFileCostInWeightPercentage / 100); - - this.sortedStatsWithCloseFileCost = Maps.newTreeMap(comparator); - mapStatistics.forEach( - (k, v) -> { - int estimatedSplits = (int) Math.ceil(v / targetWeightPerSubtask); - long estimatedCloseFileCost = closeFileCostInWeight * estimatedSplits; - sortedStatsWithCloseFileCost.put(k, v + estimatedCloseFileCost); - }); - - long totalWeightWithCloseFileCost = - sortedStatsWithCloseFileCost.values().stream().mapToLong(l -> l).sum(); - long targetWeightPerSubtaskWithCloseFileCost = - (long) Math.ceil(((double) totalWeightWithCloseFileCost) / numPartitions); - this.assignment = - buildAssignment( - numPartitions, - sortedStatsWithCloseFileCost, - targetWeightPerSubtaskWithCloseFileCost, - closeFileCostInWeight); - } - - return assignment; - } - - @VisibleForTesting - Map mapStatistics() { - return mapStatistics; - } - - /** - * Returns assignment summary for every subtask. - * - * @return assignment summary for every subtask. Key is subtaskId. Value pair is (weight assigned - * to the subtask, number of keys assigned to the subtask) - */ - Map> assignmentInfo() { - Map> assignmentInfo = Maps.newTreeMap(); - assignment.forEach( - (key, keyAssignment) -> { - for (int i = 0; i < keyAssignment.assignedSubtasks.length; ++i) { - int subtaskId = keyAssignment.assignedSubtasks[i]; - long subtaskWeight = keyAssignment.subtaskWeightsExcludingCloseCost[i]; - Pair oldValue = assignmentInfo.getOrDefault(subtaskId, Pair.of(0L, 0)); - assignmentInfo.put( - subtaskId, Pair.of(oldValue.first() + subtaskWeight, oldValue.second() + 1)); - } - }); - - return assignmentInfo; - } - - private Map buildAssignment( - int numPartitions, - NavigableMap sortedStatistics, - long targetWeightPerSubtask, - long closeFileCostInWeight) { - Map assignmentMap = - Maps.newHashMapWithExpectedSize(sortedStatistics.size()); - Iterator mapKeyIterator = sortedStatistics.keySet().iterator(); - int subtaskId = 0; - SortKey currentKey = null; - long keyRemainingWeight = 0L; - long subtaskRemainingWeight = targetWeightPerSubtask; - List assignedSubtasks = Lists.newArrayList(); - List subtaskWeights = Lists.newArrayList(); - while (mapKeyIterator.hasNext() || currentKey != null) { - // This should never happen because target weight is calculated using ceil function. - if (subtaskId >= numPartitions) { - LOG.error( - "Internal algorithm error: exhausted subtasks with unassigned keys left. 
number of partitions: {}, " - + "target weight per subtask: {}, close file cost in weight: {}, data statistics: {}", - numPartitions, - targetWeightPerSubtask, - closeFileCostInWeight, - sortedStatistics); - throw new IllegalStateException( - "Internal algorithm error: exhausted subtasks with unassigned keys left"); - } - - if (currentKey == null) { - currentKey = mapKeyIterator.next(); - keyRemainingWeight = sortedStatistics.get(currentKey); - } - - assignedSubtasks.add(subtaskId); - if (keyRemainingWeight < subtaskRemainingWeight) { - // assign the remaining weight of the key to the current subtask - subtaskWeights.add(keyRemainingWeight); - subtaskRemainingWeight -= keyRemainingWeight; - keyRemainingWeight = 0L; - } else { - // filled up the current subtask - long assignedWeight = subtaskRemainingWeight; - keyRemainingWeight -= subtaskRemainingWeight; - - // If assigned weight is less than close file cost, pad it up with close file cost. - // This might cause the subtask assigned weight over the target weight. - // But it should be no more than one close file cost. Small skew is acceptable. - if (assignedWeight <= closeFileCostInWeight) { - long paddingWeight = Math.min(keyRemainingWeight, closeFileCostInWeight); - keyRemainingWeight -= paddingWeight; - assignedWeight += paddingWeight; - } - - subtaskWeights.add(assignedWeight); - // move on to the next subtask - subtaskId += 1; - subtaskRemainingWeight = targetWeightPerSubtask; - } - - Preconditions.checkState( - assignedSubtasks.size() == subtaskWeights.size(), - "List size mismatch: assigned subtasks = %s, subtask weights = %s", - assignedSubtasks, - subtaskWeights); - - // If the remaining key weight is smaller than the close file cost, simply skip the residual - // as it doesn't make sense to assign a weight smaller than close file cost to a new subtask. - // this might lead to some inaccuracy in weight calculation. E.g., assuming the key weight is - // 2 and close file cost is 2. key weight with close cost is 4. Let's assume the previous - // task has a weight of 3 available. So weight of 3 for this key is assigned to the task and - // the residual weight of 1 is dropped. Then the routing weight for this key is 1 (minus the - // close file cost), which is inaccurate as the true key weight should be 2. - // Again, this greedy algorithm is not intended to be perfect. Some small inaccuracy is - // expected and acceptable. Traffic distribution should still be balanced. - if (keyRemainingWeight > 0 && keyRemainingWeight <= closeFileCostInWeight) { - keyRemainingWeight = 0; - } - - if (keyRemainingWeight == 0) { - // finishing up the assignment for the current key - KeyAssignment keyAssignment = - new KeyAssignment(assignedSubtasks, subtaskWeights, closeFileCostInWeight); - assignmentMap.put(currentKey, keyAssignment); - assignedSubtasks.clear(); - subtaskWeights.clear(); - currentKey = null; - } - } - - return assignmentMap; - } - - /** Subtask assignment for a key */ - @VisibleForTesting - static class KeyAssignment { - private final int[] assignedSubtasks; - private final long[] subtaskWeightsExcludingCloseCost; - private final long keyWeight; - private final long[] cumulativeWeights; - - /** - * @param assignedSubtasks assigned subtasks for this key. It could be a single subtask. It - * could also be multiple subtasks if the key has heavy weight that should be handled by - * multiple subtasks. - * @param subtaskWeightsWithCloseFileCost assigned weight for each subtask. 
E.g., if the - * keyWeight is 27 and the key is assigned to 3 subtasks, subtaskWeights could contain - * values as [10, 10, 7] for target weight of 10 per subtask. - */ - KeyAssignment( - List assignedSubtasks, - List subtaskWeightsWithCloseFileCost, - long closeFileCostInWeight) { - Preconditions.checkArgument( - assignedSubtasks != null && !assignedSubtasks.isEmpty(), - "Invalid assigned subtasks: null or empty"); - Preconditions.checkArgument( - subtaskWeightsWithCloseFileCost != null && !subtaskWeightsWithCloseFileCost.isEmpty(), - "Invalid assigned subtasks weights: null or empty"); - Preconditions.checkArgument( - assignedSubtasks.size() == subtaskWeightsWithCloseFileCost.size(), - "Invalid assignment: size mismatch (tasks length = %s, weights length = %s)", - assignedSubtasks.size(), - subtaskWeightsWithCloseFileCost.size()); - subtaskWeightsWithCloseFileCost.forEach( - weight -> - Preconditions.checkArgument( - weight > closeFileCostInWeight, - "Invalid weight: should be larger than close file cost: weight = %s, close file cost = %s", - weight, - closeFileCostInWeight)); - - this.assignedSubtasks = assignedSubtasks.stream().mapToInt(i -> i).toArray(); - // Exclude the close file cost for key routing - this.subtaskWeightsExcludingCloseCost = - subtaskWeightsWithCloseFileCost.stream() - .mapToLong(weightWithCloseFileCost -> weightWithCloseFileCost - closeFileCostInWeight) - .toArray(); - this.keyWeight = Arrays.stream(subtaskWeightsExcludingCloseCost).sum(); - this.cumulativeWeights = new long[subtaskWeightsExcludingCloseCost.length]; - long cumulativeWeight = 0; - for (int i = 0; i < subtaskWeightsExcludingCloseCost.length; ++i) { - cumulativeWeight += subtaskWeightsExcludingCloseCost[i]; - cumulativeWeights[i] = cumulativeWeight; - } - } - - /** - * Select a subtask for the key. - * - * @return subtask id - */ - int select() { - if (assignedSubtasks.length == 1) { - // only choice. no need to run random number generator. - return assignedSubtasks[0]; - } else { - long randomNumber = ThreadLocalRandom.current().nextLong(keyWeight); - int index = Arrays.binarySearch(cumulativeWeights, randomNumber); - // choose the subtask where randomNumber < cumulativeWeights[pos]. - // this works regardless whether index is negative or not. - int position = Math.abs(index + 1); - Preconditions.checkState( - position < assignedSubtasks.length, - "Invalid selected position: out of range. 
key weight = %s, random number = %s, cumulative weights array = %s", - keyWeight, - randomNumber, - cumulativeWeights); - return assignedSubtasks[position]; - } - } - - @Override - public int hashCode() { - return 31 * Arrays.hashCode(assignedSubtasks) - + Arrays.hashCode(subtaskWeightsExcludingCloseCost); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - - if (o == null || getClass() != o.getClass()) { - return false; - } - - KeyAssignment that = (KeyAssignment) o; - return Arrays.equals(assignedSubtasks, that.assignedSubtasks) - && Arrays.equals(subtaskWeightsExcludingCloseCost, that.subtaskWeightsExcludingCloseCost); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("assignedSubtasks", assignedSubtasks) - .add("subtaskWeightsExcludingCloseCost", subtaskWeightsExcludingCloseCost) - .toString(); - } + return RangePartitioner.adjustPartitionWithRescale( + partition, mapAssignment.numPartitions(), numPartitions); } } diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/RangePartitioner.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/RangePartitioner.java new file mode 100644 index 000000000000..83a9461233d2 --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/RangePartitioner.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Random; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.table.data.RowData; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** The wrapper class */ +@Internal +public class RangePartitioner implements Partitioner { + private static final Logger LOG = LoggerFactory.getLogger(RangePartitioner.class); + + private final Schema schema; + private final SortOrder sortOrder; + + private transient AtomicLong roundRobinCounter; + private transient Partitioner delegatePartitioner; + + public RangePartitioner(Schema schema, SortOrder sortOrder) { + this.schema = schema; + this.sortOrder = sortOrder; + } + + @Override + public int partition(StatisticsOrRecord wrapper, int numPartitions) { + if (wrapper.hasStatistics()) { + this.delegatePartitioner = delegatePartitioner(wrapper.statistics()); + return (int) (roundRobinCounter(numPartitions).getAndIncrement() % numPartitions); + } else { + if (delegatePartitioner != null) { + return delegatePartitioner.partition(wrapper.record(), numPartitions); + } else { + int partition = (int) (roundRobinCounter(numPartitions).getAndIncrement() % numPartitions); + LOG.trace("Statistics not available. Round robin to partition {}", partition); + return partition; + } + } + } + + private AtomicLong roundRobinCounter(int numPartitions) { + if (roundRobinCounter == null) { + // randomize the starting point to avoid synchronization across subtasks + this.roundRobinCounter = new AtomicLong(new Random().nextInt(numPartitions)); + } + + return roundRobinCounter; + } + + private Partitioner delegatePartitioner(GlobalStatistics statistics) { + if (statistics.type() == StatisticsType.Map) { + return new MapRangePartitioner(schema, sortOrder, statistics.mapAssignment()); + } else if (statistics.type() == StatisticsType.Sketch) { + return new SketchRangePartitioner(schema, sortOrder, statistics.rangeBounds()); + } else { + throw new IllegalArgumentException( + String.format("Invalid statistics type: %s. Should be Map or Sketch", statistics.type())); + } + } + + /** + * Util method that handles rescale (write parallelism / numPartitions change). + * + * @param partition partition caculated based on the existing statistics + * @param numPartitionsStatsCalculation number of partitions when the assignment was calculated + * based on + * @param numPartitions current number of partitions + * @return adjusted partition if necessary. + */ + static int adjustPartitionWithRescale( + int partition, int numPartitionsStatsCalculation, int numPartitions) { + if (numPartitionsStatsCalculation <= numPartitions) { + // no rescale or scale-up case. + // new subtasks are ignored and not assigned any keys, which is sub-optimal and only + // transient. + // when rescale is detected, operator requests new statistics from coordinator upon + // initialization. + return partition; + } else { + // scale-down case. + // Use mod % operation to distribution the over-range partitions. + // It can cause skew among subtasks. but the behavior is still better than + // discarding the statistics and falling back to round-robin (no clustering). + // Again, this is transient and stats refresh is requested when rescale is detected. 
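 +      // Illustrative example (not part of the algorithm above): if the assignment was computed
 +      // for numPartitionsStatsCalculation = 4 and the writer parallelism is later scaled down to
 +      // numPartitions = 2, a computed partition of 3 is remapped to 3 % 2 = 1.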
+ return partition % numPartitions; + } + } +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/RequestGlobalStatisticsEvent.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/RequestGlobalStatisticsEvent.java new file mode 100644 index 000000000000..ce17e1964392 --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/RequestGlobalStatisticsEvent.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.runtime.operators.coordination.OperatorEvent; + +class RequestGlobalStatisticsEvent implements OperatorEvent { + private final Integer signature; + + RequestGlobalStatisticsEvent() { + this.signature = null; + } + + /** @param signature hashCode of the subtask's existing global statistics */ + RequestGlobalStatisticsEvent(int signature) { + this.signature = signature; + } + + Integer signature() { + return signature; + } +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchDataStatistics.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchDataStatistics.java new file mode 100644 index 000000000000..35bbb27baf16 --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchDataStatistics.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Arrays; +import org.apache.datasketches.sampling.ReservoirItemsSketch; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; + +/** MapDataStatistics uses map to count key frequency */ +class SketchDataStatistics implements DataStatistics { + + private final ReservoirItemsSketch sketch; + + SketchDataStatistics(int reservoirSize) { + this.sketch = ReservoirItemsSketch.newInstance(reservoirSize); + } + + SketchDataStatistics(ReservoirItemsSketch sketchStats) { + this.sketch = sketchStats; + } + + @Override + public StatisticsType type() { + return StatisticsType.Sketch; + } + + @Override + public boolean isEmpty() { + return sketch.getNumSamples() == 0; + } + + @Override + public void add(SortKey sortKey) { + // clone the sort key first because input sortKey object can be reused + SortKey copiedKey = sortKey.copy(); + sketch.update(copiedKey); + } + + @Override + public Object result() { + return sketch; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("sketch", sketch).toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof SketchDataStatistics)) { + return false; + } + + ReservoirItemsSketch otherSketch = ((SketchDataStatistics) o).sketch; + return Objects.equal(sketch.getK(), otherSketch.getK()) + && Objects.equal(sketch.getN(), otherSketch.getN()) + && Arrays.deepEquals(sketch.getSamples(), otherSketch.getSamples()); + } + + @Override + public int hashCode() { + return Objects.hashCode(sketch.getK(), sketch.getN(), sketch.getSamples()); + } +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchRangePartitioner.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchRangePartitioner.java new file mode 100644 index 000000000000..af78271ea5dc --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchRangePartitioner.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Arrays; +import java.util.Comparator; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.table.data.RowData; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.SortOrderComparators; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.flink.FlinkSchemaUtil; +import org.apache.iceberg.flink.RowDataWrapper; + +class SketchRangePartitioner implements Partitioner { + private final SortKey sortKey; + private final Comparator comparator; + private final SortKey[] rangeBounds; + private final RowDataWrapper rowDataWrapper; + + SketchRangePartitioner(Schema schema, SortOrder sortOrder, SortKey[] rangeBounds) { + this.sortKey = new SortKey(schema, sortOrder); + this.comparator = SortOrderComparators.forSchema(schema, sortOrder); + this.rangeBounds = rangeBounds; + this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct()); + } + + @Override + public int partition(RowData row, int numPartitions) { + // reuse the sortKey and rowDataWrapper + sortKey.wrap(rowDataWrapper.wrap(row)); + int partition = Arrays.binarySearch(rangeBounds, sortKey, comparator); + + // binarySearch either returns the match location or -[insertion point]-1 + if (partition < 0) { + partition = -partition - 1; + } + + if (partition > rangeBounds.length) { + partition = rangeBounds.length; + } + + return RangePartitioner.adjustPartitionWithRescale( + partition, rangeBounds.length + 1, numPartitions); + } +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchUtil.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchUtil.java new file mode 100644 index 000000000000..a58310611e8d --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchUtil.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.Map; +import java.util.function.Consumer; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.StructLike; + +class SketchUtil { + static final int COORDINATOR_MIN_RESERVOIR_SIZE = 10_000; + static final int COORDINATOR_MAX_RESERVOIR_SIZE = 1_000_000; + static final int COORDINATOR_TARGET_PARTITIONS_MULTIPLIER = 100; + static final int OPERATOR_OVER_SAMPLE_RATIO = 10; + + // switch the statistics tracking from map to sketch if the cardinality of the sort key is over + // this threshold. It is hardcoded for now, we can revisit in the future if config is needed. 
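 +  // An illustrative reading of the two constants below (an assumption based on their names and
 +  // the comment above; the switching logic itself is not in this class): an operator would
 +  // switch from map to sketch once it has tracked more than 10,000 distinct sort keys, and the
 +  // coordinator once it has tracked more than 100,000.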
 + static final int OPERATOR_SKETCH_SWITCH_THRESHOLD = 10_000;
 + static final int COORDINATOR_SKETCH_SWITCH_THRESHOLD = 100_000;
 +
 + private SketchUtil() {}
 +
 + /**
 +  * The larger the reservoir size, the more accurate the range bounds calculation and the more
 +  * balanced the range distribution.
 +  *
 +  *

Here are the heuristic rules + *

  • Target size: numPartitions x 100 to achieve good accuracy and to make the range bounds
 + *       easier to calculate
 + *
  • Min is 10K to achieve good accuracy while memory footprint is still relatively small + *
  • Max is 1M to cap the memory footprint on coordinator + * + * @param numPartitions number of range partitions which equals to downstream operator parallelism + * @return reservoir size + */ + static int determineCoordinatorReservoirSize(int numPartitions) { + int reservoirSize = numPartitions * COORDINATOR_TARGET_PARTITIONS_MULTIPLIER; + + if (reservoirSize < COORDINATOR_MIN_RESERVOIR_SIZE) { + // adjust it up and still make reservoirSize divisible by numPartitions + int remainder = COORDINATOR_MIN_RESERVOIR_SIZE % numPartitions; + reservoirSize = COORDINATOR_MIN_RESERVOIR_SIZE + (numPartitions - remainder); + } else if (reservoirSize > COORDINATOR_MAX_RESERVOIR_SIZE) { + // adjust it down and still make reservoirSize divisible by numPartitions + int remainder = COORDINATOR_MAX_RESERVOIR_SIZE % numPartitions; + reservoirSize = COORDINATOR_MAX_RESERVOIR_SIZE - remainder; + } + + return reservoirSize; + } + + /** + * Determine the sampling reservoir size where operator subtasks collect data statistics. + * + *

    Here are the heuristic rules + *

  • Target size is "coordinator reservoir size * over sampling ratio (10) / operator
 + *       parallelism"
 + *
  • Min is 1K to achieve good accuracy while memory footprint is still relatively small + *
  • Max is 100K to cap the memory footprint on coordinator + * + * @param numPartitions number of range partitions which equals to downstream operator parallelism + * @param operatorParallelism data statistics operator parallelism + * @return reservoir size + */ + static int determineOperatorReservoirSize(int operatorParallelism, int numPartitions) { + int coordinatorReservoirSize = determineCoordinatorReservoirSize(numPartitions); + int totalOperatorSamples = coordinatorReservoirSize * OPERATOR_OVER_SAMPLE_RATIO; + return (int) Math.ceil((double) totalOperatorSamples / operatorParallelism); + } + + /** + * To understand how range bounds are used in range partitioning, here is an example for human + * ages with 4 partitions: [15, 32, 60]. The 4 ranges would be + * + *
      + *
    • age <= 15 + *
    • age > 15 && age <= 32 + *
    • age > 32 && age <= 60
 + *
    • age > 60 + *
    + * + *
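 + *
 + * For a concrete (purely illustrative) view of how this method picks the bounds: with 8 sorted
 + * integer samples [5, 12, 20, 28, 35, 45, 61, 70] and 4 partitions, the step is ceil(8 / 4) = 2,
 + * so the chosen candidates are the samples at 0-based positions 1, 3 and 5, giving the bounds
 + * [12, 28, 45].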

    Assumption is that a single key is not dominant enough to span multiple subtasks. + * + * @param numPartitions number of partitions which maps to downstream operator parallelism + * @param samples sampled keys + * @return array of range partition bounds. It should be a sorted list (ascending). Number of + * items should be {@code numPartitions - 1}. if numPartitions is 1, return an empty list + */ + static SortKey[] rangeBounds( + int numPartitions, Comparator comparator, SortKey[] samples) { + // sort the keys first + Arrays.sort(samples, comparator); + int numCandidates = numPartitions - 1; + SortKey[] candidates = new SortKey[numCandidates]; + int step = (int) Math.ceil((double) samples.length / numPartitions); + int position = step - 1; + int numChosen = 0; + while (position < samples.length && numChosen < numCandidates) { + SortKey candidate = samples[position]; + // skip duplicate values + if (numChosen > 0 && candidate.equals(candidates[numChosen - 1])) { + // linear probe for the next distinct value + position += 1; + } else { + candidates[numChosen] = candidate; + position += step; + numChosen += 1; + } + } + + return candidates; + } + + /** This can be a bit expensive since it is quadratic. */ + static void convertMapToSketch( + Map taskMapStats, Consumer sketchConsumer) { + taskMapStats.forEach( + (sortKey, count) -> { + for (int i = 0; i < count; ++i) { + sketchConsumer.accept(sortKey); + } + }); + } +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySerializer.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySerializer.java index d03409f2a430..4ddc5a32d6bf 100644 --- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySerializer.java +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySerializer.java @@ -276,13 +276,12 @@ public static class SortKeySerializerSnapshot implements TypeSerializerSnapshot< private Schema schema; private SortOrder sortOrder; - @SuppressWarnings({"checkstyle:RedundantModifier", "WeakerAccess"}) + /** Constructor for read instantiation. */ + @SuppressWarnings({"unused", "checkstyle:RedundantModifier"}) public SortKeySerializerSnapshot() { // this constructor is used when restoring from a checkpoint. } - // constructors need to public. Otherwise, Flink state restore would complain - // "The class has no (implicit) public nullary constructor". 
@SuppressWarnings("checkstyle:RedundantModifier") public SortKeySerializerSnapshot(Schema schema, SortOrder sortOrder) { this.schema = schema; @@ -320,8 +319,12 @@ public TypeSerializerSchemaCompatibility resolveSchemaCompatibility( return TypeSerializerSchemaCompatibility.incompatible(); } - SortKeySerializer newAvroSerializer = (SortKeySerializer) newSerializer; - return resolveSchemaCompatibility(newAvroSerializer.schema, schema); + SortKeySerializer sortKeySerializer = (SortKeySerializer) newSerializer; + if (!sortOrder.sameOrder(sortKeySerializer.sortOrder)) { + return TypeSerializerSchemaCompatibility.incompatible(); + } + + return resolveSchemaCompatibility(sortKeySerializer.schema, schema); } @Override diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySketchSerializer.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySketchSerializer.java new file mode 100644 index 000000000000..d6c23f035015 --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySketchSerializer.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.IOException; +import java.io.Serializable; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.List; +import org.apache.datasketches.common.ArrayOfItemsSerDe; +import org.apache.datasketches.common.ArrayOfStringsSerDe; +import org.apache.datasketches.common.ByteArrayUtil; +import org.apache.datasketches.common.Util; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.sampling.ReservoirItemsSketch; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; +import org.apache.flink.core.memory.DataInputDeserializer; +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; + +/** + * Only way to implement {@link ReservoirItemsSketch} serializer is to extend from {@link + * ArrayOfItemsSerDe}, as deserialization uses a private constructor from ReservoirItemsSketch. 
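 + * A note on the byte layout, stated as an observation of the methods below rather than a
 + * contract: both serializeToByteArray variants write a 4-byte little-endian length prefix
 + * followed by the Flink-serialized payload, and deserializeFromMemory reads that length back
 + * before decoding the payload.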
The + * implementation is modeled after {@link ArrayOfStringsSerDe} + */ +class SortKeySketchSerializer extends ArrayOfItemsSerDe implements Serializable { + private static final int DEFAULT_SORT_KEY_SIZE = 128; + + private final TypeSerializer itemSerializer; + private final ListSerializer listSerializer; + private final DataInputDeserializer input; + + SortKeySketchSerializer(TypeSerializer itemSerializer) { + this.itemSerializer = itemSerializer; + this.listSerializer = new ListSerializer<>(itemSerializer); + this.input = new DataInputDeserializer(); + } + + @Override + public byte[] serializeToByteArray(SortKey item) { + try { + DataOutputSerializer output = new DataOutputSerializer(DEFAULT_SORT_KEY_SIZE); + itemSerializer.serialize(item, output); + byte[] itemBytes = output.getSharedBuffer(); + int numBytes = output.length(); + byte[] out = new byte[numBytes + Integer.BYTES]; + ByteArrayUtil.copyBytes(itemBytes, 0, out, 4, numBytes); + ByteArrayUtil.putIntLE(out, 0, numBytes); + return out; + } catch (IOException e) { + throw new UncheckedIOException("Failed to serialize sort key", e); + } + } + + @Override + public byte[] serializeToByteArray(SortKey[] items) { + try { + DataOutputSerializer output = new DataOutputSerializer(DEFAULT_SORT_KEY_SIZE * items.length); + listSerializer.serialize(Arrays.asList(items), output); + byte[] itemsBytes = output.getSharedBuffer(); + int numBytes = output.length(); + byte[] out = new byte[Integer.BYTES + numBytes]; + ByteArrayUtil.putIntLE(out, 0, numBytes); + System.arraycopy(itemsBytes, 0, out, Integer.BYTES, numBytes); + return out; + } catch (IOException e) { + throw new UncheckedIOException("Failed to serialize sort key", e); + } + } + + @Override + public SortKey[] deserializeFromMemory(Memory mem, long startingOffset, int numItems) { + Preconditions.checkArgument(mem != null, "Invalid input memory: null"); + if (numItems <= 0) { + return new SortKey[0]; + } + + long offset = startingOffset; + Util.checkBounds(offset, Integer.BYTES, mem.getCapacity()); + int numBytes = mem.getInt(offset); + offset += Integer.BYTES; + + Util.checkBounds(offset, numBytes, mem.getCapacity()); + byte[] sortKeyBytes = new byte[numBytes]; + mem.getByteArray(offset, sortKeyBytes, 0, numBytes); + input.setBuffer(sortKeyBytes); + + try { + List sortKeys = listSerializer.deserialize(input); + SortKey[] array = new SortKey[numItems]; + sortKeys.toArray(array); + input.releaseArrays(); + return array; + } catch (IOException e) { + throw new UncheckedIOException("Failed to deserialize sort key sketch", e); + } + } + + @Override + public int sizeOf(SortKey item) { + return serializeToByteArray(item).length; + } + + @Override + public int sizeOf(Memory mem, long offset, int numItems) { + Preconditions.checkArgument(mem != null, "Invalid input memory: null"); + if (numItems <= 0) { + return 0; + } + + Util.checkBounds(offset, Integer.BYTES, mem.getCapacity()); + int numBytes = mem.getInt(offset); + return Integer.BYTES + numBytes; + } + + @Override + public String toString(SortKey item) { + return item.toString(); + } + + @Override + public Class getClassOfT() { + return SortKey.class; + } +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsEvent.java similarity index 58% rename from flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java rename to 
flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsEvent.java index 852d2157b8cb..f6fcdb8b16ef 100644 --- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsEvent.java @@ -27,24 +27,39 @@ * statistics in bytes */ @Internal -class DataStatisticsEvent, S> implements OperatorEvent { +class StatisticsEvent implements OperatorEvent { private static final long serialVersionUID = 1L; private final long checkpointId; private final byte[] statisticsBytes; + private final boolean applyImmediately; - private DataStatisticsEvent(long checkpointId, byte[] statisticsBytes) { + private StatisticsEvent(long checkpointId, byte[] statisticsBytes, boolean applyImmediately) { this.checkpointId = checkpointId; this.statisticsBytes = statisticsBytes; + this.applyImmediately = applyImmediately; } - static , S> DataStatisticsEvent create( + static StatisticsEvent createTaskStatisticsEvent( long checkpointId, - DataStatistics dataStatistics, - TypeSerializer> statisticsSerializer) { - return new DataStatisticsEvent<>( + DataStatistics statistics, + TypeSerializer statisticsSerializer) { + // applyImmediately is really only relevant for coordinator to operator event. + // task reported statistics is always merged immediately by the coordinator. + return new StatisticsEvent( checkpointId, - DataStatisticsUtil.serializeDataStatistics(dataStatistics, statisticsSerializer)); + StatisticsUtil.serializeDataStatistics(statistics, statisticsSerializer), + true); + } + + static StatisticsEvent createGlobalStatisticsEvent( + GlobalStatistics statistics, + TypeSerializer statisticsSerializer, + boolean applyImmediately) { + return new StatisticsEvent( + statistics.checkpointId(), + StatisticsUtil.serializeGlobalStatistics(statistics, statisticsSerializer), + applyImmediately); } long checkpointId() { @@ -54,4 +69,8 @@ long checkpointId() { byte[] statisticsBytes() { return statisticsBytes; } + + boolean applyImmediately() { + return applyImmediately; + } } diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsOrRecord.java similarity index 66% rename from flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java rename to flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsOrRecord.java index 889e85112e16..bc28df2b0e22 100644 --- a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsOrRecord.java @@ -19,6 +19,7 @@ package org.apache.iceberg.flink.sink.shuffle; import java.io.Serializable; +import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.table.data.RowData; import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; @@ -34,68 +35,66 @@ * After shuffling, a filter and mapper are required to filter out the data distribution weight, * unwrap the object and extract the original record type T. 
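 *
 * A purely illustrative sketch of that unwrapping (the {@code shuffled} stream name is an
 * assumption, not part of this change):
 *
 *   DataStream<RowData> records =
 *       shuffled.filter(StatisticsOrRecord::hasRecord).map(StatisticsOrRecord::record);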
*/ -class DataStatisticsOrRecord, S> implements Serializable { +@Internal +public class StatisticsOrRecord implements Serializable { private static final long serialVersionUID = 1L; - private DataStatistics statistics; + private GlobalStatistics statistics; private RowData record; - private DataStatisticsOrRecord(DataStatistics statistics, RowData record) { + private StatisticsOrRecord(GlobalStatistics statistics, RowData record) { Preconditions.checkArgument( record != null ^ statistics != null, "DataStatistics or record, not neither or both"); this.statistics = statistics; this.record = record; } - static , S> DataStatisticsOrRecord fromRecord( - RowData record) { - return new DataStatisticsOrRecord<>(null, record); + static StatisticsOrRecord fromRecord(RowData record) { + return new StatisticsOrRecord(null, record); } - static , S> DataStatisticsOrRecord fromDataStatistics( - DataStatistics statistics) { - return new DataStatisticsOrRecord<>(statistics, null); + static StatisticsOrRecord fromStatistics(GlobalStatistics statistics) { + return new StatisticsOrRecord(statistics, null); } - static , S> DataStatisticsOrRecord reuseRecord( - DataStatisticsOrRecord reuse, TypeSerializer recordSerializer) { + static StatisticsOrRecord reuseRecord( + StatisticsOrRecord reuse, TypeSerializer recordSerializer) { if (reuse.hasRecord()) { return reuse; } else { // not reusable - return DataStatisticsOrRecord.fromRecord(recordSerializer.createInstance()); + return StatisticsOrRecord.fromRecord(recordSerializer.createInstance()); } } - static , S> DataStatisticsOrRecord reuseStatistics( - DataStatisticsOrRecord reuse, - TypeSerializer> statisticsSerializer) { - if (reuse.hasDataStatistics()) { + static StatisticsOrRecord reuseStatistics( + StatisticsOrRecord reuse, TypeSerializer statisticsSerializer) { + if (reuse.hasStatistics()) { return reuse; } else { // not reusable - return DataStatisticsOrRecord.fromDataStatistics(statisticsSerializer.createInstance()); + return StatisticsOrRecord.fromStatistics(statisticsSerializer.createInstance()); } } - boolean hasDataStatistics() { + boolean hasStatistics() { return statistics != null; } - boolean hasRecord() { + public boolean hasRecord() { return record != null; } - DataStatistics dataStatistics() { + GlobalStatistics statistics() { return statistics; } - void dataStatistics(DataStatistics newStatistics) { + void statistics(GlobalStatistics newStatistics) { this.statistics = newStatistics; } - RowData record() { + public RowData record() { return record; } diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecordSerializer.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsOrRecordSerializer.java similarity index 53% rename from flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecordSerializer.java rename to flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsOrRecordSerializer.java index e9a6fa0cbfc5..6e403425938d 100644 --- a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecordSerializer.java +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsOrRecordSerializer.java @@ -29,13 +29,12 @@ import org.apache.flink.table.data.RowData; @Internal -class DataStatisticsOrRecordSerializer, S> - extends TypeSerializer> { - private final TypeSerializer> statisticsSerializer; +class StatisticsOrRecordSerializer extends TypeSerializer { + 
private final TypeSerializer statisticsSerializer; private final TypeSerializer recordSerializer; - DataStatisticsOrRecordSerializer( - TypeSerializer> statisticsSerializer, + StatisticsOrRecordSerializer( + TypeSerializer statisticsSerializer, TypeSerializer recordSerializer) { this.statisticsSerializer = statisticsSerializer; this.recordSerializer = recordSerializer; @@ -48,13 +47,13 @@ public boolean isImmutableType() { @SuppressWarnings("ReferenceEquality") @Override - public TypeSerializer> duplicate() { - TypeSerializer> duplicateStatisticsSerializer = + public TypeSerializer duplicate() { + TypeSerializer duplicateStatisticsSerializer = statisticsSerializer.duplicate(); TypeSerializer duplicateRowDataSerializer = recordSerializer.duplicate(); if ((statisticsSerializer != duplicateStatisticsSerializer) || (recordSerializer != duplicateRowDataSerializer)) { - return new DataStatisticsOrRecordSerializer<>( + return new StatisticsOrRecordSerializer( duplicateStatisticsSerializer, duplicateRowDataSerializer); } else { return this; @@ -62,34 +61,31 @@ public TypeSerializer> duplicate() { } @Override - public DataStatisticsOrRecord createInstance() { + public StatisticsOrRecord createInstance() { // arbitrarily always create RowData value instance - return DataStatisticsOrRecord.fromRecord(recordSerializer.createInstance()); + return StatisticsOrRecord.fromRecord(recordSerializer.createInstance()); } @Override - public DataStatisticsOrRecord copy(DataStatisticsOrRecord from) { + public StatisticsOrRecord copy(StatisticsOrRecord from) { if (from.hasRecord()) { - return DataStatisticsOrRecord.fromRecord(recordSerializer.copy(from.record())); + return StatisticsOrRecord.fromRecord(recordSerializer.copy(from.record())); } else { - return DataStatisticsOrRecord.fromDataStatistics( - statisticsSerializer.copy(from.dataStatistics())); + return StatisticsOrRecord.fromStatistics(statisticsSerializer.copy(from.statistics())); } } @Override - public DataStatisticsOrRecord copy( - DataStatisticsOrRecord from, DataStatisticsOrRecord reuse) { - DataStatisticsOrRecord to; + public StatisticsOrRecord copy(StatisticsOrRecord from, StatisticsOrRecord reuse) { + StatisticsOrRecord to; if (from.hasRecord()) { - to = DataStatisticsOrRecord.reuseRecord(reuse, recordSerializer); + to = StatisticsOrRecord.reuseRecord(reuse, recordSerializer); RowData record = recordSerializer.copy(from.record(), to.record()); to.record(record); } else { - to = DataStatisticsOrRecord.reuseStatistics(reuse, statisticsSerializer); - DataStatistics statistics = - statisticsSerializer.copy(from.dataStatistics(), to.dataStatistics()); - to.dataStatistics(statistics); + to = StatisticsOrRecord.reuseStatistics(reuse, statisticsSerializer); + GlobalStatistics statistics = statisticsSerializer.copy(from.statistics(), to.statistics()); + to.statistics(statistics); } return to; @@ -101,41 +97,40 @@ public int getLength() { } @Override - public void serialize(DataStatisticsOrRecord statisticsOrRecord, DataOutputView target) + public void serialize(StatisticsOrRecord statisticsOrRecord, DataOutputView target) throws IOException { if (statisticsOrRecord.hasRecord()) { target.writeBoolean(true); recordSerializer.serialize(statisticsOrRecord.record(), target); } else { target.writeBoolean(false); - statisticsSerializer.serialize(statisticsOrRecord.dataStatistics(), target); + statisticsSerializer.serialize(statisticsOrRecord.statistics(), target); } } @Override - public DataStatisticsOrRecord deserialize(DataInputView source) throws IOException { 
+ public StatisticsOrRecord deserialize(DataInputView source) throws IOException { boolean isRecord = source.readBoolean(); if (isRecord) { - return DataStatisticsOrRecord.fromRecord(recordSerializer.deserialize(source)); + return StatisticsOrRecord.fromRecord(recordSerializer.deserialize(source)); } else { - return DataStatisticsOrRecord.fromDataStatistics(statisticsSerializer.deserialize(source)); + return StatisticsOrRecord.fromStatistics(statisticsSerializer.deserialize(source)); } } @Override - public DataStatisticsOrRecord deserialize( - DataStatisticsOrRecord reuse, DataInputView source) throws IOException { - DataStatisticsOrRecord to; + public StatisticsOrRecord deserialize(StatisticsOrRecord reuse, DataInputView source) + throws IOException { + StatisticsOrRecord to; boolean isRecord = source.readBoolean(); if (isRecord) { - to = DataStatisticsOrRecord.reuseRecord(reuse, recordSerializer); + to = StatisticsOrRecord.reuseRecord(reuse, recordSerializer); RowData record = recordSerializer.deserialize(to.record(), source); to.record(record); } else { - to = DataStatisticsOrRecord.reuseStatistics(reuse, statisticsSerializer); - DataStatistics statistics = - statisticsSerializer.deserialize(to.dataStatistics(), source); - to.dataStatistics(statistics); + to = StatisticsOrRecord.reuseStatistics(reuse, statisticsSerializer); + GlobalStatistics statistics = statisticsSerializer.deserialize(to.statistics(), source); + to.statistics(statistics); } return to; @@ -154,12 +149,11 @@ public void copy(DataInputView source, DataOutputView target) throws IOException @Override public boolean equals(Object obj) { - if (!(obj instanceof DataStatisticsOrRecordSerializer)) { + if (!(obj instanceof StatisticsOrRecordSerializer)) { return false; } - @SuppressWarnings("unchecked") - DataStatisticsOrRecordSerializer other = (DataStatisticsOrRecordSerializer) obj; + StatisticsOrRecordSerializer other = (StatisticsOrRecordSerializer) obj; return Objects.equals(statisticsSerializer, other.statisticsSerializer) && Objects.equals(recordSerializer, other.recordSerializer); } @@ -170,25 +164,22 @@ public int hashCode() { } @Override - public TypeSerializerSnapshot> snapshotConfiguration() { - return new DataStatisticsOrRecordSerializerSnapshot<>(this); + public TypeSerializerSnapshot snapshotConfiguration() { + return new StatisticsOrRecordSerializerSnapshot(this); } - public static class DataStatisticsOrRecordSerializerSnapshot, S> - extends CompositeTypeSerializerSnapshot< - DataStatisticsOrRecord, DataStatisticsOrRecordSerializer> { + public static class StatisticsOrRecordSerializerSnapshot + extends CompositeTypeSerializerSnapshot { private static final int CURRENT_VERSION = 1; - // constructors need to public. Otherwise, Flink state restore would complain - // "The class has no (implicit) public nullary constructor". - @SuppressWarnings("checkstyle:RedundantModifier") - public DataStatisticsOrRecordSerializerSnapshot() { - super(DataStatisticsOrRecordSerializer.class); + /** Constructor for read instantiation. 
*/ + @SuppressWarnings({"unused", "checkstyle:RedundantModifier"}) + public StatisticsOrRecordSerializerSnapshot() { + super(StatisticsOrRecordSerializer.class); } @SuppressWarnings("checkstyle:RedundantModifier") - public DataStatisticsOrRecordSerializerSnapshot( - DataStatisticsOrRecordSerializer serializer) { + public StatisticsOrRecordSerializerSnapshot(StatisticsOrRecordSerializer serializer) { super(serializer); } @@ -200,7 +191,7 @@ protected int getCurrentOuterSnapshotVersion() { @Override protected TypeSerializer[] getNestedSerializers( - DataStatisticsOrRecordSerializer outerSerializer) { + StatisticsOrRecordSerializer outerSerializer) { return new TypeSerializer[] { outerSerializer.statisticsSerializer, outerSerializer.recordSerializer }; @@ -208,12 +199,12 @@ protected TypeSerializer[] getNestedSerializers( @SuppressWarnings("unchecked") @Override - protected DataStatisticsOrRecordSerializer createOuterSerializerWithNestedSerializers( + protected StatisticsOrRecordSerializer createOuterSerializerWithNestedSerializers( TypeSerializer[] nestedSerializers) { - TypeSerializer> statisticsSerializer = - (TypeSerializer>) nestedSerializers[0]; + TypeSerializer statisticsSerializer = + (TypeSerializer) nestedSerializers[0]; TypeSerializer recordSerializer = (TypeSerializer) nestedSerializers[1]; - return new DataStatisticsOrRecordSerializer<>(statisticsSerializer, recordSerializer); + return new StatisticsOrRecordSerializer(statisticsSerializer, recordSerializer); } } } diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsType.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsType.java new file mode 100644 index 000000000000..43f72e336e06 --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsType.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +/** + * Range distribution requires gathering statistics on the sort keys to determine proper range + * boundaries to distribute/cluster rows before writer operators. + */ +public enum StatisticsType { + /** + * Tracks the data statistics as {@code Map} frequency. It works better for + * low-cardinality scenarios (like country, event_type, etc.) where the cardinalities are in + * hundreds or thousands. + * + *

      + *
    • Pro: accurate measurement on the statistics/weight of every key. + *
    • Con: memory footprint can be large if the key cardinality is high. + *
    + */ + Map, + + /** + * Sample the sort keys via reservoir sampling. Then split the range partitions via range bounds + * from sampled values. It works better for high-cardinality scenarios (like device_id, user_id, + * uuid etc.) where the cardinalities can be in millions or billions. + * + *
      + *
    • Pro: relatively low memory footprint for high-cardinality sort keys. + *
    • Con: non-precise approximation with potentially lower accuracy. + *
    + */ + Sketch, + + /** + * Initially use Map for statistics tracking. If key cardinality turns out to be high, + * automatically switch to sketch sampling. + */ + Auto +} diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsUtil.java b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsUtil.java new file mode 100644 index 000000000000..5d48ec57ca49 --- /dev/null +++ b/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsUtil.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.IOException; +import java.io.UncheckedIOException; +import javax.annotation.Nullable; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputDeserializer; +import org.apache.flink.core.memory.DataOutputSerializer; + +class StatisticsUtil { + + private StatisticsUtil() {} + + static DataStatistics createTaskStatistics( + StatisticsType type, int operatorParallelism, int numPartitions) { + if (type == StatisticsType.Map) { + return new MapDataStatistics(); + } else { + return new SketchDataStatistics( + SketchUtil.determineOperatorReservoirSize(operatorParallelism, numPartitions)); + } + } + + static byte[] serializeDataStatistics( + DataStatistics dataStatistics, TypeSerializer statisticsSerializer) { + DataOutputSerializer out = new DataOutputSerializer(64); + try { + statisticsSerializer.serialize(dataStatistics, out); + return out.getCopyOfBuffer(); + } catch (IOException e) { + throw new UncheckedIOException("Fail to serialize data statistics", e); + } + } + + static DataStatistics deserializeDataStatistics( + byte[] bytes, TypeSerializer statisticsSerializer) { + DataInputDeserializer input = new DataInputDeserializer(bytes, 0, bytes.length); + try { + return statisticsSerializer.deserialize(input); + } catch (IOException e) { + throw new UncheckedIOException("Fail to deserialize data statistics", e); + } + } + + static byte[] serializeCompletedStatistics( + CompletedStatistics completedStatistics, + TypeSerializer statisticsSerializer) { + try { + DataOutputSerializer out = new DataOutputSerializer(1024); + statisticsSerializer.serialize(completedStatistics, out); + return out.getCopyOfBuffer(); + } catch (IOException e) { + throw new UncheckedIOException("Fail to serialize aggregated statistics", e); + } + } + + static CompletedStatistics deserializeCompletedStatistics( + byte[] bytes, TypeSerializer statisticsSerializer) { + try { + DataInputDeserializer input = new DataInputDeserializer(bytes); + return statisticsSerializer.deserialize(input); + } catch (IOException e) { + throw new UncheckedIOException("Fail to 
deserialize aggregated statistics", e); + } + } + + static byte[] serializeGlobalStatistics( + GlobalStatistics globalStatistics, TypeSerializer statisticsSerializer) { + try { + DataOutputSerializer out = new DataOutputSerializer(1024); + statisticsSerializer.serialize(globalStatistics, out); + return out.getCopyOfBuffer(); + } catch (IOException e) { + throw new UncheckedIOException("Fail to serialize aggregated statistics", e); + } + } + + static GlobalStatistics deserializeGlobalStatistics( + byte[] bytes, TypeSerializer statisticsSerializer) { + try { + DataInputDeserializer input = new DataInputDeserializer(bytes); + return statisticsSerializer.deserialize(input); + } catch (IOException e) { + throw new UncheckedIOException("Fail to deserialize aggregated statistics", e); + } + } + + static StatisticsType collectType(StatisticsType config) { + return config == StatisticsType.Sketch ? StatisticsType.Sketch : StatisticsType.Map; + } + + static StatisticsType collectType(StatisticsType config, @Nullable GlobalStatistics statistics) { + if (statistics != null) { + return statistics.type(); + } + + return collectType(config); + } + + static StatisticsType collectType( + StatisticsType config, @Nullable CompletedStatistics statistics) { + if (statistics != null) { + return statistics.type(); + } + + return collectType(config); + } +} diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/Fixtures.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/Fixtures.java new file mode 100644 index 000000000000..5910bd685510 --- /dev/null +++ b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/Fixtures.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Comparator; +import java.util.Map; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.typeutils.RowDataSerializer; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.SortOrderComparators; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.flink.RowDataWrapper; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; + +class Fixtures { + private Fixtures() {} + + public static final int NUM_SUBTASKS = 2; + public static final Schema SCHEMA = + new Schema( + Types.NestedField.optional(1, "id", Types.StringType.get()), + Types.NestedField.optional(2, "number", Types.IntegerType.get())); + public static final RowType ROW_TYPE = RowType.of(new VarCharType(), new IntType()); + public static final TypeSerializer ROW_SERIALIZER = new RowDataSerializer(ROW_TYPE); + public static final RowDataWrapper ROW_WRAPPER = new RowDataWrapper(ROW_TYPE, SCHEMA.asStruct()); + public static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA).asc("id").build(); + public static final Comparator SORT_ORDER_COMPARTOR = + SortOrderComparators.forSchema(SCHEMA, SORT_ORDER); + public static final SortKeySerializer SORT_KEY_SERIALIZER = + new SortKeySerializer(SCHEMA, SORT_ORDER); + public static final DataStatisticsSerializer TASK_STATISTICS_SERIALIZER = + new DataStatisticsSerializer(SORT_KEY_SERIALIZER); + public static final GlobalStatisticsSerializer GLOBAL_STATISTICS_SERIALIZER = + new GlobalStatisticsSerializer(SORT_KEY_SERIALIZER); + public static final CompletedStatisticsSerializer COMPLETED_STATISTICS_SERIALIZER = + new CompletedStatisticsSerializer(SORT_KEY_SERIALIZER); + + public static final SortKey SORT_KEY = new SortKey(SCHEMA, SORT_ORDER); + public static final Map CHAR_KEYS = createCharKeys(); + + public static StatisticsEvent createStatisticsEvent( + StatisticsType type, + TypeSerializer statisticsSerializer, + long checkpointId, + SortKey... keys) { + DataStatistics statistics = createTaskStatistics(type, keys); + return StatisticsEvent.createTaskStatisticsEvent( + checkpointId, statistics, statisticsSerializer); + } + + public static DataStatistics createTaskStatistics(StatisticsType type, SortKey... 
keys) { + DataStatistics statistics; + if (type == StatisticsType.Sketch) { + statistics = new SketchDataStatistics(128); + } else { + statistics = new MapDataStatistics(); + } + + for (SortKey key : keys) { + statistics.add(key); + } + + return statistics; + } + + private static Map createCharKeys() { + Map keys = Maps.newHashMap(); + for (char c = 'a'; c <= 'z'; ++c) { + String key = Character.toString(c); + SortKey sortKey = SORT_KEY.copy(); + sortKey.set(0, key); + keys.put(key, sortKey); + } + + return keys; + } +} diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java deleted file mode 100644 index 739cf764e2a6..000000000000 --- a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.iceberg.flink.sink.shuffle; - -import static org.assertj.core.api.Assertions.assertThat; - -import java.util.Map; -import org.apache.iceberg.Schema; -import org.apache.iceberg.SortKey; -import org.apache.iceberg.SortOrder; -import org.apache.iceberg.types.Types; -import org.junit.jupiter.api.Test; - -public class TestAggregatedStatistics { - private final Schema schema = - new Schema(Types.NestedField.optional(1, "str", Types.StringType.get())); - private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build(); - private final SortKey sortKey = new SortKey(schema, sortOrder); - private final MapDataStatisticsSerializer statisticsSerializer = - MapDataStatisticsSerializer.fromSortKeySerializer(new SortKeySerializer(schema, sortOrder)); - - @Test - public void mergeDataStatisticTest() { - SortKey keyA = sortKey.copy(); - keyA.set(0, "a"); - SortKey keyB = sortKey.copy(); - keyB.set(0, "b"); - - AggregatedStatistics> aggregatedStatistics = - new AggregatedStatistics<>(1, statisticsSerializer); - MapDataStatistics mapDataStatistics1 = new MapDataStatistics(); - mapDataStatistics1.add(keyA); - mapDataStatistics1.add(keyA); - mapDataStatistics1.add(keyB); - aggregatedStatistics.mergeDataStatistic("testOperator", 1, mapDataStatistics1); - MapDataStatistics mapDataStatistics2 = new MapDataStatistics(); - mapDataStatistics2.add(keyA); - aggregatedStatistics.mergeDataStatistic("testOperator", 1, mapDataStatistics2); - assertThat(aggregatedStatistics.dataStatistics().statistics().get(keyA)) - .isEqualTo( - mapDataStatistics1.statistics().get(keyA) + mapDataStatistics2.statistics().get(keyA)); - assertThat(aggregatedStatistics.dataStatistics().statistics().get(keyB)) - .isEqualTo( - mapDataStatistics1.statistics().get(keyB) - + 
mapDataStatistics2.statistics().getOrDefault(keyB, 0L)); - } -} diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java index 0064c91340bf..8322ce683768 100644 --- a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java +++ b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java @@ -18,161 +18,448 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.TASK_STATISTICS_SERIALIZER; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.createStatisticsEvent; import static org.assertj.core.api.Assertions.assertThat; -import java.util.Map; -import org.apache.iceberg.Schema; -import org.apache.iceberg.SortKey; -import org.apache.iceberg.SortOrder; -import org.apache.iceberg.types.Types; -import org.junit.jupiter.api.BeforeEach; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; public class TestAggregatedStatisticsTracker { - private static final int NUM_SUBTASKS = 2; - - private final Schema schema = - new Schema(Types.NestedField.optional(1, "str", Types.StringType.get())); - private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build(); - private final SortKey sortKey = new SortKey(schema, sortOrder); - private final MapDataStatisticsSerializer statisticsSerializer = - MapDataStatisticsSerializer.fromSortKeySerializer(new SortKeySerializer(schema, sortOrder)); - private final SortKey keyA = sortKey.copy(); - private final SortKey keyB = sortKey.copy(); - - private AggregatedStatisticsTracker> - aggregatedStatisticsTracker; - - public TestAggregatedStatisticsTracker() { - keyA.set(0, "a"); - keyB.set(0, "b"); - } + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void receiveNewerStatisticsEvent(StatisticsType type) { + AggregatedStatisticsTracker tracker = createTracker(type); - @BeforeEach - public void before() throws Exception { - aggregatedStatisticsTracker = - new AggregatedStatisticsTracker<>("testOperator", statisticsSerializer, NUM_SUBTASKS); - } + StatisticsEvent checkpoint1Subtask0StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("a")); + CompletedStatistics completedStatistics = + tracker.updateAndCheckCompletion(0, checkpoint1Subtask0StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L); + AggregatedStatisticsTracker.Aggregation aggregation = + tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()).isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("a")); + } - @Test - public void receiveNewerDataStatisticEvent() { - MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics(); - 
checkpoint1Subtask0DataStatistic.add(keyA); - DataStatisticsEvent> - checkpoint1Subtask0DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer); - assertThat( - aggregatedStatisticsTracker.updateAndCheckCompletion( - 0, checkpoint1Subtask0DataStatisticEvent)) - .isNull(); - assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(1); - - MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics(); - checkpoint2Subtask0DataStatistic.add(keyA); - DataStatisticsEvent> - checkpoint2Subtask0DataStatisticEvent = - DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer); - assertThat( - aggregatedStatisticsTracker.updateAndCheckCompletion( - 0, checkpoint2Subtask0DataStatisticEvent)) - .isNull(); - // Checkpoint 2 is newer than checkpoint1, thus dropping in progress statistics for checkpoint1 - assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(2); + StatisticsEvent checkpoint2Subtask0StatisticsEvent = + createStatisticsEvent( + type, + TASK_STATISTICS_SERIALIZER, + 2L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b")); + completedStatistics = tracker.updateAndCheckCompletion(0, checkpoint2Subtask0StatisticsEvent); + assertThat(completedStatistics).isNull(); + // both checkpoints are tracked + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L, 2L); + aggregation = tracker.aggregationsPerCheckpoint().get(2L); + assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 2L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("b")); + } + + StatisticsEvent checkpoint1Subtask1StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("b")); + completedStatistics = tracker.updateAndCheckCompletion(1, checkpoint1Subtask1StatisticsEvent); + // checkpoint 1 is completed + assertThat(completedStatistics).isNotNull(); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + assertThat(completedStatistics.checkpointId()).isEqualTo(1L); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()) + .isEqualTo( + ImmutableMap.of( + CHAR_KEYS.get("a"), 1L, + CHAR_KEYS.get("b"), 1L)); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly(CHAR_KEYS.get("a"), CHAR_KEYS.get("b")); + } + + // checkpoint 2 remains + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(2L); + aggregation = tracker.aggregationsPerCheckpoint().get(2L); + assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 2L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("b")); + } } - @Test - public void receiveOlderDataStatisticEventTest() { - MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics(); - 
checkpoint2Subtask0DataStatistic.add(keyA); - checkpoint2Subtask0DataStatistic.add(keyB); - checkpoint2Subtask0DataStatistic.add(keyB); - DataStatisticsEvent> - checkpoint3Subtask0DataStatisticEvent = - DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer); - assertThat( - aggregatedStatisticsTracker.updateAndCheckCompletion( - 0, checkpoint3Subtask0DataStatisticEvent)) - .isNull(); - - MapDataStatistics checkpoint1Subtask1DataStatistic = new MapDataStatistics(); - checkpoint1Subtask1DataStatistic.add(keyB); - DataStatisticsEvent> - checkpoint1Subtask1DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, statisticsSerializer); - // Receive event from old checkpoint, aggregatedStatisticsAggregatorTracker won't return - // completed statistics and in progress statistics won't be updated - assertThat( - aggregatedStatisticsTracker.updateAndCheckCompletion( - 1, checkpoint1Subtask1DataStatisticEvent)) - .isNull(); - assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(2); + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void receiveOlderStatisticsEventTest(StatisticsType type) { + AggregatedStatisticsTracker tracker = createTracker(type); + + StatisticsEvent checkpoint2Subtask0StatisticsEvent = + createStatisticsEvent( + type, + TASK_STATISTICS_SERIALIZER, + 2L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b")); + CompletedStatistics completedStatistics = + tracker.updateAndCheckCompletion(0, checkpoint2Subtask0StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(2L); + AggregatedStatisticsTracker.Aggregation aggregation = + tracker.aggregationsPerCheckpoint().get(2L); + assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 2L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("b")); + } + + StatisticsEvent checkpoint1Subtask1StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("b")); + completedStatistics = tracker.updateAndCheckCompletion(1, checkpoint1Subtask1StatisticsEvent); + assertThat(completedStatistics).isNull(); + // both checkpoints are tracked + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L, 2L); + aggregation = tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()).isEqualTo(ImmutableMap.of(CHAR_KEYS.get("b"), 1L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("b")); + } + + StatisticsEvent checkpoint3Subtask0StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 3L, CHAR_KEYS.get("x")); + completedStatistics = tracker.updateAndCheckCompletion(1, checkpoint3Subtask0StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L, 2L, 3L); + aggregation = tracker.aggregationsPerCheckpoint().get(3L); + 
assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()).isEqualTo(ImmutableMap.of(CHAR_KEYS.get("x"), 1L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("x")); + } + + StatisticsEvent checkpoint2Subtask1StatisticsEvent = + createStatisticsEvent( + type, + TASK_STATISTICS_SERIALIZER, + 2L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b")); + completedStatistics = tracker.updateAndCheckCompletion(1, checkpoint2Subtask1StatisticsEvent); + // checkpoint 1 is cleared along with checkpoint 2. checkpoint 3 remains + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(3L); + aggregation = tracker.aggregationsPerCheckpoint().get(3L); + assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()).isEqualTo(ImmutableMap.of(CHAR_KEYS.get("x"), 1L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("x")); + } + + assertThat(completedStatistics).isNotNull(); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + assertThat(completedStatistics.checkpointId()).isEqualTo(2L); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()) + .isEqualTo( + ImmutableMap.of( + CHAR_KEYS.get("a"), 2L, + CHAR_KEYS.get("b"), 4L)); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly( + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b")); + } } - @Test - public void receiveCompletedDataStatisticEvent() { - MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics(); - checkpoint1Subtask0DataStatistic.add(keyA); - checkpoint1Subtask0DataStatistic.add(keyB); - checkpoint1Subtask0DataStatistic.add(keyB); - DataStatisticsEvent> - checkpoint1Subtask0DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer); - assertThat( - aggregatedStatisticsTracker.updateAndCheckCompletion( - 0, checkpoint1Subtask0DataStatisticEvent)) - .isNull(); - - MapDataStatistics checkpoint1Subtask1DataStatistic = new MapDataStatistics(); - checkpoint1Subtask1DataStatistic.add(keyA); - checkpoint1Subtask1DataStatistic.add(keyA); - checkpoint1Subtask1DataStatistic.add(keyB); - DataStatisticsEvent> - checkpoint1Subtask1DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, statisticsSerializer); + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void receiveCompletedStatisticsEvent(StatisticsType type) { + AggregatedStatisticsTracker tracker = createTracker(type); + + StatisticsEvent checkpoint1Subtask0DataStatisticEvent = + createStatisticsEvent( + type, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b")); + + CompletedStatistics completedStatistics = + tracker.updateAndCheckCompletion(0, checkpoint1Subtask0DataStatisticEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L); + AggregatedStatisticsTracker.Aggregation aggregation = + 
tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.subtaskSet()).containsExactlyInAnyOrder(0); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 2L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("b")); + } + + StatisticsEvent checkpoint1Subtask1DataStatisticEvent = + createStatisticsEvent( + type, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b")); + // Receive data statistics from all subtasks at checkpoint 1 - AggregatedStatistics> completedStatistics = - aggregatedStatisticsTracker.updateAndCheckCompletion( - 1, checkpoint1Subtask1DataStatisticEvent); + completedStatistics = + tracker.updateAndCheckCompletion(1, checkpoint1Subtask1DataStatisticEvent); + assertThat(tracker.aggregationsPerCheckpoint()).isEmpty(); assertThat(completedStatistics).isNotNull(); - assertThat(completedStatistics.checkpointId()).isEqualTo(1); - MapDataStatistics globalDataStatistics = - (MapDataStatistics) completedStatistics.dataStatistics(); - assertThat((long) globalDataStatistics.statistics().get(keyA)) - .isEqualTo( - checkpoint1Subtask0DataStatistic.statistics().get(keyA) - + checkpoint1Subtask1DataStatistic.statistics().get(keyA)); - assertThat((long) globalDataStatistics.statistics().get(keyB)) - .isEqualTo( - checkpoint1Subtask0DataStatistic.statistics().get(keyB) - + checkpoint1Subtask1DataStatistic.statistics().get(keyB)); - assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()) - .isEqualTo(completedStatistics.checkpointId() + 1); - - MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics(); - checkpoint2Subtask0DataStatistic.add(keyA); - DataStatisticsEvent> - checkpoint2Subtask0DataStatisticEvent = - DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer); - assertThat( - aggregatedStatisticsTracker.updateAndCheckCompletion( - 0, checkpoint2Subtask0DataStatisticEvent)) - .isNull(); - assertThat(completedStatistics.checkpointId()).isEqualTo(1); - - MapDataStatistics checkpoint2Subtask1DataStatistic = new MapDataStatistics(); - checkpoint2Subtask1DataStatistic.add(keyB); - DataStatisticsEvent> - checkpoint2Subtask1DataStatisticEvent = - DataStatisticsEvent.create(2, checkpoint2Subtask1DataStatistic, statisticsSerializer); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + assertThat(completedStatistics.checkpointId()).isEqualTo(1L); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()) + .isEqualTo( + ImmutableMap.of( + CHAR_KEYS.get("a"), 3L, + CHAR_KEYS.get("b"), 3L)); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly( + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("a"), + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b")); + } + + StatisticsEvent checkpoint2Subtask0DataStatisticEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 2L, CHAR_KEYS.get("a")); + completedStatistics = + tracker.updateAndCheckCompletion(0, checkpoint2Subtask0DataStatisticEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(2L); + aggregation = tracker.aggregationsPerCheckpoint().get(2L); + 
assertThat(aggregation.subtaskSet()).containsExactlyInAnyOrder(0); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()).isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("a")); + } + + StatisticsEvent checkpoint2Subtask1DataStatisticEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 2L, CHAR_KEYS.get("b")); // Receive data statistics from all subtasks at checkpoint 2 completedStatistics = - aggregatedStatisticsTracker.updateAndCheckCompletion( - 1, checkpoint2Subtask1DataStatisticEvent); + tracker.updateAndCheckCompletion(1, checkpoint2Subtask1DataStatisticEvent); + assertThat(tracker.aggregationsPerCheckpoint()).isEmpty(); + + assertThat(completedStatistics).isNotNull(); + assertThat(completedStatistics.checkpointId()).isEqualTo(2L); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()) + .isEqualTo( + ImmutableMap.of( + CHAR_KEYS.get("a"), 1L, + CHAR_KEYS.get("b"), 1L)); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly(CHAR_KEYS.get("a"), CHAR_KEYS.get("b")); + } + } + + @Test + public void coordinatorSwitchToSketchOverThreshold() { + int parallelism = 3; + int downstreamParallelism = 3; + int switchToSketchThreshold = 3; + AggregatedStatisticsTracker tracker = + new AggregatedStatisticsTracker( + "testOperator", + parallelism, + Fixtures.SCHEMA, + Fixtures.SORT_ORDER, + downstreamParallelism, + StatisticsType.Auto, + switchToSketchThreshold, + null); + + StatisticsEvent checkpoint1Subtask0StatisticsEvent = + createStatisticsEvent( + StatisticsType.Map, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b")); + CompletedStatistics completedStatistics = + tracker.updateAndCheckCompletion(0, checkpoint1Subtask0StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L); + AggregatedStatisticsTracker.Aggregation aggregation = + tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.subtaskSet()).containsExactlyInAnyOrder(0); + assertThat(aggregation.currentType()).isEqualTo(StatisticsType.Map); + assertThat(aggregation.sketchStatistics()).isNull(); + assertThat(aggregation.mapStatistics()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 1L)); + + StatisticsEvent checkpoint1Subtask1StatisticsEvent = + createStatisticsEvent( + StatisticsType.Map, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d")); + completedStatistics = tracker.updateAndCheckCompletion(1, checkpoint1Subtask1StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L); + aggregation = tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.subtaskSet()).containsExactlyInAnyOrder(0, 1); + // converted to sketch statistics as map size is 4 (over the switch threshold of 3) + assertThat(aggregation.currentType()).isEqualTo(StatisticsType.Sketch); + assertThat(aggregation.mapStatistics()).isNull(); + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder( + CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("c"), CHAR_KEYS.get("d")); + StatisticsEvent checkpoint1Subtask2StatisticsEvent = + createStatisticsEvent( + 
StatisticsType.Map, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f")); + completedStatistics = tracker.updateAndCheckCompletion(2, checkpoint1Subtask2StatisticsEvent); + assertThat(tracker.aggregationsPerCheckpoint()).isEmpty(); assertThat(completedStatistics).isNotNull(); - assertThat(completedStatistics.checkpointId()).isEqualTo(2); - assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()) - .isEqualTo(completedStatistics.checkpointId() + 1); + assertThat(completedStatistics.type()).isEqualTo(StatisticsType.Sketch); + assertThat(completedStatistics.keySamples()) + .containsExactly( + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d"), + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f")); + } + + @Test + public void coordinatorMapOperatorSketch() { + int parallelism = 3; + int downstreamParallelism = 3; + AggregatedStatisticsTracker tracker = + new AggregatedStatisticsTracker( + "testOperator", + parallelism, + Fixtures.SCHEMA, + Fixtures.SORT_ORDER, + downstreamParallelism, + StatisticsType.Auto, + SketchUtil.COORDINATOR_SKETCH_SWITCH_THRESHOLD, + null); + + // first operator event has map statistics + StatisticsEvent checkpoint1Subtask0StatisticsEvent = + createStatisticsEvent( + StatisticsType.Map, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b")); + CompletedStatistics completedStatistics = + tracker.updateAndCheckCompletion(0, checkpoint1Subtask0StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L); + AggregatedStatisticsTracker.Aggregation aggregation = + tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.subtaskSet()).containsExactlyInAnyOrder(0); + assertThat(aggregation.currentType()).isEqualTo(StatisticsType.Map); + assertThat(aggregation.sketchStatistics()).isNull(); + assertThat(aggregation.mapStatistics()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 1L)); + + // second operator event contains sketch statistics + StatisticsEvent checkpoint1Subtask1StatisticsEvent = + createStatisticsEvent( + StatisticsType.Sketch, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d")); + completedStatistics = tracker.updateAndCheckCompletion(1, checkpoint1Subtask1StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L); + aggregation = tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.subtaskSet()).containsExactlyInAnyOrder(0, 1); + assertThat(aggregation.currentType()).isEqualTo(StatisticsType.Sketch); + assertThat(aggregation.mapStatistics()).isNull(); + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder( + CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("c"), CHAR_KEYS.get("d")); + + // third operator event has Map statistics + StatisticsEvent checkpoint1Subtask2StatisticsEvent = + createStatisticsEvent( + StatisticsType.Map, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f")); + completedStatistics = tracker.updateAndCheckCompletion(2, checkpoint1Subtask2StatisticsEvent); + assertThat(tracker.aggregationsPerCheckpoint()).isEmpty(); + assertThat(completedStatistics).isNotNull(); + assertThat(completedStatistics.type()).isEqualTo(StatisticsType.Sketch); + assertThat(completedStatistics.keySamples()) + .containsExactly( + 
CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d"), + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f")); + } + + private AggregatedStatisticsTracker createTracker(StatisticsType type) { + return new AggregatedStatisticsTracker( + "testOperator", + Fixtures.NUM_SUBTASKS, + Fixtures.SCHEMA, + Fixtures.SORT_ORDER, + Fixtures.NUM_SUBTASKS, + type, + SketchUtil.COORDINATOR_SKETCH_SWITCH_THRESHOLD, + null); } } diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestCompletedStatisticsSerializer.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestCompletedStatisticsSerializer.java new file mode 100644 index 000000000000..4ee9888934a8 --- /dev/null +++ b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestCompletedStatisticsSerializer.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; + +import org.apache.flink.api.common.typeutils.SerializerTestBase; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; + +public class TestCompletedStatisticsSerializer extends SerializerTestBase { + + @Override + protected TypeSerializer createSerializer() { + return Fixtures.COMPLETED_STATISTICS_SERIALIZER; + } + + @Override + protected int getLength() { + return -1; + } + + @Override + protected Class getTypeClass() { + return CompletedStatistics.class; + } + + @Override + protected CompletedStatistics[] getTestData() { + + return new CompletedStatistics[] { + CompletedStatistics.fromKeyFrequency( + 1L, ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 2L)), + CompletedStatistics.fromKeySamples(2L, new SortKey[] {CHAR_KEYS.get("a"), CHAR_KEYS.get("b")}) + }; + } +} diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java index 849253564209..a08a8a73e80c 100644 --- a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java +++ b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java @@ -18,9 +18,13 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.NUM_SUBTASKS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.SORT_ORDER_COMPARTOR; import static 
org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import java.time.Duration; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -28,128 +32,182 @@ import org.apache.flink.runtime.operators.coordination.EventReceivingTasks; import org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext; import org.apache.flink.util.ExceptionUtils; -import org.apache.iceberg.Schema; import org.apache.iceberg.SortKey; -import org.apache.iceberg.SortOrder; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; -import org.apache.iceberg.types.Types; +import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; public class TestDataStatisticsCoordinator { private static final String OPERATOR_NAME = "TestCoordinator"; private static final OperatorID TEST_OPERATOR_ID = new OperatorID(1234L, 5678L); - private static final int NUM_SUBTASKS = 2; - - private final Schema schema = - new Schema(Types.NestedField.optional(1, "str", Types.StringType.get())); - private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build(); - private final SortKey sortKey = new SortKey(schema, sortOrder); - private final MapDataStatisticsSerializer statisticsSerializer = - MapDataStatisticsSerializer.fromSortKeySerializer(new SortKeySerializer(schema, sortOrder)); private EventReceivingTasks receivingTasks; - private DataStatisticsCoordinator> - dataStatisticsCoordinator; @BeforeEach public void before() throws Exception { receivingTasks = EventReceivingTasks.createForRunningTasks(); - dataStatisticsCoordinator = - new DataStatisticsCoordinator<>( - OPERATOR_NAME, - new MockOperatorCoordinatorContext(TEST_OPERATOR_ID, NUM_SUBTASKS), - statisticsSerializer); } - private void tasksReady() throws Exception { - dataStatisticsCoordinator.start(); - setAllTasksReady(NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks); + private void tasksReady(DataStatisticsCoordinator coordinator) { + setAllTasksReady(NUM_SUBTASKS, coordinator, receivingTasks); } - @Test - public void testThrowExceptionWhenNotStarted() { - String failureMessage = "The coordinator of TestCoordinator has not started yet."; - - assertThatThrownBy( - () -> - dataStatisticsCoordinator.handleEventFromOperator( - 0, - 0, - DataStatisticsEvent.create(0, new MapDataStatistics(), statisticsSerializer))) - .isInstanceOf(IllegalStateException.class) - .hasMessage(failureMessage); - assertThatThrownBy(() -> dataStatisticsCoordinator.executionAttemptFailed(0, 0, null)) - .isInstanceOf(IllegalStateException.class) - .hasMessage(failureMessage); - assertThatThrownBy(() -> dataStatisticsCoordinator.checkpointCoordinator(0, null)) - .isInstanceOf(IllegalStateException.class) - .hasMessage(failureMessage); + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void testThrowExceptionWhenNotStarted(StatisticsType type) throws Exception { + try (DataStatisticsCoordinator dataStatisticsCoordinator = createCoordinator(type)) { + String failureMessage = "The coordinator of TestCoordinator has not started yet."; + assertThatThrownBy( + () -> + dataStatisticsCoordinator.handleEventFromOperator( + 0, + 0, + StatisticsEvent.createTaskStatisticsEvent( + 0, new MapDataStatistics(), Fixtures.TASK_STATISTICS_SERIALIZER))) + 
.isInstanceOf(IllegalStateException.class) + .hasMessage(failureMessage); + assertThatThrownBy(() -> dataStatisticsCoordinator.executionAttemptFailed(0, 0, null)) + .isInstanceOf(IllegalStateException.class) + .hasMessage(failureMessage); + assertThatThrownBy(() -> dataStatisticsCoordinator.checkpointCoordinator(0, null)) + .isInstanceOf(IllegalStateException.class) + .hasMessage(failureMessage); + } + } + + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void testDataStatisticsEventHandling(StatisticsType type) throws Exception { + try (DataStatisticsCoordinator dataStatisticsCoordinator = createCoordinator(type)) { + dataStatisticsCoordinator.start(); + tasksReady(dataStatisticsCoordinator); + + StatisticsEvent checkpoint1Subtask0DataStatisticEvent = + Fixtures.createStatisticsEvent( + type, + Fixtures.TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c")); + StatisticsEvent checkpoint1Subtask1DataStatisticEvent = + Fixtures.createStatisticsEvent( + type, + Fixtures.TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c")); + // Handle events from operators for checkpoint 1 + dataStatisticsCoordinator.handleEventFromOperator( + 0, 0, checkpoint1Subtask0DataStatisticEvent); + dataStatisticsCoordinator.handleEventFromOperator( + 1, 0, checkpoint1Subtask1DataStatisticEvent); + + waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + + Map keyFrequency = + ImmutableMap.of( + CHAR_KEYS.get("a"), 2L, + CHAR_KEYS.get("b"), 3L, + CHAR_KEYS.get("c"), 5L); + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(NUM_SUBTASKS, keyFrequency, 0.0d, SORT_ORDER_COMPARTOR); + + CompletedStatistics completedStatistics = dataStatisticsCoordinator.completedStatistics(); + assertThat(completedStatistics.checkpointId()).isEqualTo(1L); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()).isEqualTo(keyFrequency); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly( + CHAR_KEYS.get("a"), + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c")); + } + + GlobalStatistics globalStatistics = dataStatisticsCoordinator.globalStatistics(); + assertThat(globalStatistics.checkpointId()).isEqualTo(1L); + assertThat(globalStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(globalStatistics.mapAssignment()).isEqualTo(mapAssignment); + } else { + assertThat(globalStatistics.rangeBounds()).containsExactly(CHAR_KEYS.get("b")); + } + } } @Test - public void testDataStatisticsEventHandling() throws Exception { - tasksReady(); - SortKey key = sortKey.copy(); - - MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics(); - key.set(0, "a"); - checkpoint1Subtask0DataStatistic.add(key); - key.set(0, "b"); - checkpoint1Subtask0DataStatistic.add(key); - key.set(0, "b"); - checkpoint1Subtask0DataStatistic.add(key); - key.set(0, "c"); - checkpoint1Subtask0DataStatistic.add(key); - key.set(0, "c"); - checkpoint1Subtask0DataStatistic.add(key); - key.set(0, "c"); - checkpoint1Subtask0DataStatistic.add(key); - - 
DataStatisticsEvent> - checkpoint1Subtask0DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer); - - MapDataStatistics checkpoint1Subtask1DataStatistic = new MapDataStatistics(); - key.set(0, "a"); - checkpoint1Subtask1DataStatistic.add(key); - key.set(0, "b"); - checkpoint1Subtask1DataStatistic.add(key); - key.set(0, "c"); - checkpoint1Subtask1DataStatistic.add(key); - key.set(0, "c"); - checkpoint1Subtask1DataStatistic.add(key); - - DataStatisticsEvent> - checkpoint1Subtask1DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, statisticsSerializer); - - // Handle events from operators for checkpoint 1 - dataStatisticsCoordinator.handleEventFromOperator(0, 0, checkpoint1Subtask0DataStatisticEvent); - dataStatisticsCoordinator.handleEventFromOperator(1, 0, checkpoint1Subtask1DataStatisticEvent); - - waitForCoordinatorToProcessActions(dataStatisticsCoordinator); - - // Verify global data statistics is the aggregation of all subtasks data statistics - SortKey keyA = sortKey.copy(); - keyA.set(0, "a"); - SortKey keyB = sortKey.copy(); - keyB.set(0, "b"); - SortKey keyC = sortKey.copy(); - keyC.set(0, "c"); - MapDataStatistics globalDataStatistics = - (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics(); - assertThat(globalDataStatistics.statistics()) - .containsExactlyInAnyOrderEntriesOf( - ImmutableMap.of( - keyA, 2L, - keyB, 3L, - keyC, 5L)); + public void testRequestGlobalStatisticsEventHandling() throws Exception { + try (DataStatisticsCoordinator dataStatisticsCoordinator = + createCoordinator(StatisticsType.Sketch)) { + dataStatisticsCoordinator.start(); + tasksReady(dataStatisticsCoordinator); + + // receive request before global statistics is ready + dataStatisticsCoordinator.handleEventFromOperator(0, 0, new RequestGlobalStatisticsEvent()); + assertThat(receivingTasks.getSentEventsForSubtask(0)).isEmpty(); + assertThat(receivingTasks.getSentEventsForSubtask(1)).isEmpty(); + + StatisticsEvent checkpoint1Subtask0DataStatisticEvent = + Fixtures.createStatisticsEvent( + StatisticsType.Sketch, Fixtures.TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("a")); + StatisticsEvent checkpoint1Subtask1DataStatisticEvent = + Fixtures.createStatisticsEvent( + StatisticsType.Sketch, Fixtures.TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("b")); + // Handle events from operators for checkpoint 1 + dataStatisticsCoordinator.handleEventFromOperator( + 0, 0, checkpoint1Subtask0DataStatisticEvent); + dataStatisticsCoordinator.handleEventFromOperator( + 1, 0, checkpoint1Subtask1DataStatisticEvent); + + waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + Awaitility.await("wait for statistics event") + .pollInterval(Duration.ofMillis(10)) + .atMost(Duration.ofSeconds(10)) + .until(() -> receivingTasks.getSentEventsForSubtask(0).size() == 1); + assertThat(receivingTasks.getSentEventsForSubtask(0).get(0)) + .isInstanceOf(StatisticsEvent.class); + + Awaitility.await("wait for statistics event") + .pollInterval(Duration.ofMillis(10)) + .atMost(Duration.ofSeconds(10)) + .until(() -> receivingTasks.getSentEventsForSubtask(1).size() == 1); + assertThat(receivingTasks.getSentEventsForSubtask(1).get(0)) + .isInstanceOf(StatisticsEvent.class); + + dataStatisticsCoordinator.handleEventFromOperator(1, 0, new RequestGlobalStatisticsEvent()); + + // coordinator should send a response to subtask 1 + Awaitility.await("wait for statistics event") + .pollInterval(Duration.ofMillis(10)) + 
.atMost(Duration.ofSeconds(10)) + .until(() -> receivingTasks.getSentEventsForSubtask(1).size() == 2); + assertThat(receivingTasks.getSentEventsForSubtask(1).get(0)) + .isInstanceOf(StatisticsEvent.class); + assertThat(receivingTasks.getSentEventsForSubtask(1).get(1)) + .isInstanceOf(StatisticsEvent.class); + } } static void setAllTasksReady( int subtasks, - DataStatisticsCoordinator> dataStatisticsCoordinator, + DataStatisticsCoordinator dataStatisticsCoordinator, EventReceivingTasks receivingTasks) { for (int i = 0; i < subtasks; i++) { dataStatisticsCoordinator.executionAttemptReady( @@ -157,8 +215,7 @@ static void setAllTasksReady( } } - static void waitForCoordinatorToProcessActions( - DataStatisticsCoordinator> coordinator) { + static void waitForCoordinatorToProcessActions(DataStatisticsCoordinator coordinator) { CompletableFuture future = new CompletableFuture<>(); coordinator.callInCoordinatorThread( () -> { @@ -175,4 +232,15 @@ static void waitForCoordinatorToProcessActions( ExceptionUtils.rethrow(ExceptionUtils.stripExecutionException(e)); } } + + private static DataStatisticsCoordinator createCoordinator(StatisticsType type) { + return new DataStatisticsCoordinator( + OPERATOR_NAME, + new MockOperatorCoordinatorContext(TEST_OPERATOR_ID, NUM_SUBTASKS), + Fixtures.SCHEMA, + Fixtures.SORT_ORDER, + NUM_SUBTASKS, + type, + 0.0d); + } } diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java index c5216eeb712a..6317f2bfde18 100644 --- a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java +++ b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java @@ -18,6 +18,10 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.SORT_ORDER_COMPARTOR; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.TASK_STATISTICS_SERIALIZER; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.createStatisticsEvent; import static org.assertj.core.api.Assertions.assertThat; import java.util.Map; @@ -27,117 +31,157 @@ import org.apache.flink.runtime.operators.coordination.EventReceivingTasks; import org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext; import org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator; -import org.apache.iceberg.Schema; import org.apache.iceberg.SortKey; -import org.apache.iceberg.SortOrder; -import org.apache.iceberg.types.Types; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; public class TestDataStatisticsCoordinatorProvider { private static final OperatorID OPERATOR_ID = new OperatorID(); - private static final int NUM_SUBTASKS = 1; - private final Schema schema = - new Schema(Types.NestedField.optional(1, "str", Types.StringType.get())); - private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build(); - private final SortKey sortKey = new SortKey(schema, sortOrder); - private final MapDataStatisticsSerializer statisticsSerializer = - 
MapDataStatisticsSerializer.fromSortKeySerializer(new SortKeySerializer(schema, sortOrder)); - - private DataStatisticsCoordinatorProvider> provider; private EventReceivingTasks receivingTasks; @BeforeEach public void before() { - provider = - new DataStatisticsCoordinatorProvider<>( - "DataStatisticsCoordinatorProvider", OPERATOR_ID, statisticsSerializer); receivingTasks = EventReceivingTasks.createForRunningTasks(); } - @Test - @SuppressWarnings("unchecked") - public void testCheckpointAndReset() throws Exception { - SortKey keyA = sortKey.copy(); - keyA.set(0, "a"); - SortKey keyB = sortKey.copy(); - keyB.set(0, "b"); - SortKey keyC = sortKey.copy(); - keyC.set(0, "c"); - SortKey keyD = sortKey.copy(); - keyD.set(0, "c"); - SortKey keyE = sortKey.copy(); - keyE.set(0, "c"); - + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void testCheckpointAndReset(StatisticsType type) throws Exception { + DataStatisticsCoordinatorProvider provider = createProvider(type, Fixtures.NUM_SUBTASKS); try (RecreateOnResetOperatorCoordinator coordinator = (RecreateOnResetOperatorCoordinator) - provider.create(new MockOperatorCoordinatorContext(OPERATOR_ID, NUM_SUBTASKS))) { - DataStatisticsCoordinator> dataStatisticsCoordinator = - (DataStatisticsCoordinator>) - coordinator.getInternalCoordinator(); + provider.create( + new MockOperatorCoordinatorContext(OPERATOR_ID, Fixtures.NUM_SUBTASKS))) { + DataStatisticsCoordinator dataStatisticsCoordinator = + (DataStatisticsCoordinator) coordinator.getInternalCoordinator(); // Start the coordinator coordinator.start(); TestDataStatisticsCoordinator.setAllTasksReady( - NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks); - MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics(); - checkpoint1Subtask0DataStatistic.add(keyA); - checkpoint1Subtask0DataStatistic.add(keyB); - checkpoint1Subtask0DataStatistic.add(keyC); - DataStatisticsEvent> - checkpoint1Subtask0DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer); + Fixtures.NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks); // Handle events from operators for checkpoint 1 - coordinator.handleEventFromOperator(0, 0, checkpoint1Subtask0DataStatisticEvent); + StatisticsEvent checkpoint1Subtask0StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("a")); + coordinator.handleEventFromOperator(0, 0, checkpoint1Subtask0StatisticsEvent); TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + + StatisticsEvent checkpoint1Subtask1StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("b")); + coordinator.handleEventFromOperator(1, 0, checkpoint1Subtask1StatisticsEvent); + TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + // Verify checkpoint 1 global data statistics - MapDataStatistics checkpoint1GlobalDataStatistics = - (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics(); - assertThat(checkpoint1GlobalDataStatistics.statistics()) - .isEqualTo(checkpoint1Subtask0DataStatistic.statistics()); + Map checkpoint1KeyFrequency = + ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 1L); + MapAssignment checkpoint1MapAssignment = + MapAssignment.fromKeyFrequency( + Fixtures.NUM_SUBTASKS, checkpoint1KeyFrequency, 0.0d, SORT_ORDER_COMPARTOR); + + CompletedStatistics completedStatistics = dataStatisticsCoordinator.completedStatistics(); 
+ assertThat(completedStatistics).isNotNull(); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()).isEqualTo(checkpoint1KeyFrequency); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly(CHAR_KEYS.get("a"), CHAR_KEYS.get("b")); + } + + GlobalStatistics globalStatistics = dataStatisticsCoordinator.globalStatistics(); + assertThat(globalStatistics).isNotNull(); + assertThat(globalStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(globalStatistics.mapAssignment()).isEqualTo(checkpoint1MapAssignment); + } else { + assertThat(globalStatistics.rangeBounds()).containsExactly(CHAR_KEYS.get("a")); + } + byte[] checkpoint1Bytes = waitForCheckpoint(1L, dataStatisticsCoordinator); - MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics(); - checkpoint2Subtask0DataStatistic.add(keyD); - checkpoint2Subtask0DataStatistic.add(keyE); - checkpoint2Subtask0DataStatistic.add(keyE); - DataStatisticsEvent> - checkpoint2Subtask0DataStatisticEvent = - DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer); - // Handle events from operators for checkpoint 2 - coordinator.handleEventFromOperator(0, 0, checkpoint2Subtask0DataStatisticEvent); + StatisticsEvent checkpoint2Subtask0StatisticsEvent = + createStatisticsEvent( + type, TASK_STATISTICS_SERIALIZER, 2L, CHAR_KEYS.get("d"), CHAR_KEYS.get("e")); + coordinator.handleEventFromOperator(0, 0, checkpoint2Subtask0StatisticsEvent); TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + + StatisticsEvent checkpoint2Subtask1StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 2L, CHAR_KEYS.get("f")); + coordinator.handleEventFromOperator(1, 0, checkpoint2Subtask1StatisticsEvent); + TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + // Verify checkpoint 2 global data statistics - MapDataStatistics checkpoint2GlobalDataStatistics = - (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics(); - assertThat(checkpoint2GlobalDataStatistics.statistics()) - .isEqualTo(checkpoint2Subtask0DataStatistic.statistics()); + Map checkpoint2KeyFrequency = + ImmutableMap.of(CHAR_KEYS.get("d"), 1L, CHAR_KEYS.get("e"), 1L, CHAR_KEYS.get("f"), 1L); + MapAssignment checkpoint2MapAssignment = + MapAssignment.fromKeyFrequency( + Fixtures.NUM_SUBTASKS, checkpoint2KeyFrequency, 0.0d, SORT_ORDER_COMPARTOR); + completedStatistics = dataStatisticsCoordinator.completedStatistics(); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()).isEqualTo(checkpoint2KeyFrequency); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly(CHAR_KEYS.get("d"), CHAR_KEYS.get("e"), CHAR_KEYS.get("f")); + } + + globalStatistics = dataStatisticsCoordinator.globalStatistics(); + assertThat(globalStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(globalStatistics.mapAssignment()).isEqualTo(checkpoint2MapAssignment); + } else { + assertThat(globalStatistics.rangeBounds()).containsExactly(CHAR_KEYS.get("e")); + } + 
waitForCheckpoint(2L, dataStatisticsCoordinator); // Reset coordinator to checkpoint 1 coordinator.resetToCheckpoint(1L, checkpoint1Bytes); - DataStatisticsCoordinator> - restoredDataStatisticsCoordinator = - (DataStatisticsCoordinator>) - coordinator.getInternalCoordinator(); - assertThat(dataStatisticsCoordinator).isNotEqualTo(restoredDataStatisticsCoordinator); + DataStatisticsCoordinator restoredDataStatisticsCoordinator = + (DataStatisticsCoordinator) coordinator.getInternalCoordinator(); + assertThat(dataStatisticsCoordinator).isNotSameAs(restoredDataStatisticsCoordinator); + + completedStatistics = restoredDataStatisticsCoordinator.completedStatistics(); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); // Verify restored data statistics - MapDataStatistics restoredAggregateDataStatistics = - (MapDataStatistics) - restoredDataStatisticsCoordinator.completedStatistics().dataStatistics(); - assertThat(restoredAggregateDataStatistics.statistics()) - .isEqualTo(checkpoint1GlobalDataStatistics.statistics()); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 1L)); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly(CHAR_KEYS.get("a"), CHAR_KEYS.get("b")); + } + + globalStatistics = restoredDataStatisticsCoordinator.globalStatistics(); + assertThat(globalStatistics).isNotNull(); + assertThat(globalStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(globalStatistics.mapAssignment()).isEqualTo(checkpoint1MapAssignment); + } else { + assertThat(globalStatistics.rangeBounds()).containsExactly(CHAR_KEYS.get("a")); + } } } - private byte[] waitForCheckpoint( - long checkpointId, - DataStatisticsCoordinator> coordinator) + private byte[] waitForCheckpoint(long checkpointId, DataStatisticsCoordinator coordinator) throws InterruptedException, ExecutionException { CompletableFuture future = new CompletableFuture<>(); coordinator.checkpointCoordinator(checkpointId, future); return future.get(); } + + private static DataStatisticsCoordinatorProvider createProvider( + StatisticsType type, int downstreamParallelism) { + return new DataStatisticsCoordinatorProvider( + "DataStatisticsCoordinatorProvider", + OPERATOR_ID, + Fixtures.SCHEMA, + Fixtures.SORT_ORDER, + downstreamParallelism, + type, + 0.0); + } } diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java index 5e6f971807ba..c760f1ba96d3 100644 --- a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java +++ b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java @@ -18,22 +18,25 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.SORT_ORDER_COMPARTOR; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import java.util.stream.Stream; +import 
org.apache.datasketches.sampling.ReservoirItemsSketch; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.state.OperatorStateStore; -import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.operators.coordination.MockOperatorEventGateway; -import org.apache.flink.runtime.operators.testutils.MockEnvironment; -import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.StateInitializationContext; @@ -49,102 +52,95 @@ import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.StringData; -import org.apache.flink.table.runtime.typeutils.RowDataSerializer; -import org.apache.flink.table.types.logical.IntType; -import org.apache.flink.table.types.logical.RowType; -import org.apache.flink.table.types.logical.VarCharType; -import org.apache.iceberg.Schema; import org.apache.iceberg.SortKey; -import org.apache.iceberg.SortOrder; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; -import org.apache.iceberg.relocated.com.google.common.collect.Maps; -import org.apache.iceberg.types.Types; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; public class TestDataStatisticsOperator { - private final Schema schema = - new Schema( - Types.NestedField.optional(1, "id", Types.StringType.get()), - Types.NestedField.optional(2, "number", Types.IntegerType.get())); - private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("id").build(); - private final SortKey sortKey = new SortKey(schema, sortOrder); - private final RowType rowType = RowType.of(new VarCharType(), new IntType()); - private final TypeSerializer rowSerializer = new RowDataSerializer(rowType); - private final TypeSerializer>> - statisticsSerializer = - MapDataStatisticsSerializer.fromSortKeySerializer( - new SortKeySerializer(schema, sortOrder)); - - private DataStatisticsOperator> operator; - - private Environment getTestingEnvironment() { - return new StreamMockEnvironment( - new Configuration(), - new Configuration(), - new ExecutionConfig(), - 1L, - new MockInputSplitProvider(), - 1, - new TestTaskStateManager()); - } + + private Environment env; @BeforeEach public void before() throws Exception { - this.operator = createOperator(); - Environment env = getTestingEnvironment(); - this.operator.setup( - new OneInputStreamTask(env), - new MockStreamConfig(new Configuration(), 1), - new MockOutput<>(Lists.newArrayList())); + this.env = + new StreamMockEnvironment( + new Configuration(), + new Configuration(), + new ExecutionConfig(), + 1L, + new MockInputSplitProvider(), + 1, + new TestTaskStateManager()); } - private 
DataStatisticsOperator> createOperator() { + private DataStatisticsOperator createOperator(StatisticsType type, int downstreamParallelism) + throws Exception { MockOperatorEventGateway mockGateway = new MockOperatorEventGateway(); - return new DataStatisticsOperator<>( - "testOperator", schema, sortOrder, mockGateway, statisticsSerializer); + return createOperator(type, downstreamParallelism, mockGateway); } - @AfterEach - public void clean() throws Exception { - operator.close(); + private DataStatisticsOperator createOperator( + StatisticsType type, int downstreamParallelism, MockOperatorEventGateway mockGateway) + throws Exception { + DataStatisticsOperator operator = + new DataStatisticsOperator( + "testOperator", + Fixtures.SCHEMA, + Fixtures.SORT_ORDER, + mockGateway, + downstreamParallelism, + type); + operator.setup( + new OneInputStreamTask(env), + new MockStreamConfig(new Configuration(), 1), + new MockOutput<>(Lists.newArrayList())); + return operator; } - @Test - public void testProcessElement() throws Exception { - try (OneInputStreamOperatorTestHarness< - RowData, DataStatisticsOrRecord>> - testHarness = createHarness(this.operator)) { + @SuppressWarnings("unchecked") + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void testProcessElement(StatisticsType type) throws Exception { + DataStatisticsOperator operator = createOperator(type, Fixtures.NUM_SUBTASKS); + try (OneInputStreamOperatorTestHarness testHarness = + createHarness(operator)) { StateInitializationContext stateContext = getStateContext(); operator.initializeState(stateContext); operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("a"), 5))); operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("a"), 3))); operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("b"), 1))); - assertThat(operator.localDataStatistics()).isInstanceOf(MapDataStatistics.class); - SortKey keyA = sortKey.copy(); - keyA.set(0, "a"); - SortKey keyB = sortKey.copy(); - keyB.set(0, "b"); - Map expectedMap = ImmutableMap.of(keyA, 2L, keyB, 1L); - - MapDataStatistics mapDataStatistics = (MapDataStatistics) operator.localDataStatistics(); - Map statsMap = mapDataStatistics.statistics(); - assertThat(statsMap).hasSize(2); - assertThat(statsMap).containsExactlyInAnyOrderEntriesOf(expectedMap); + DataStatistics localStatistics = operator.localStatistics(); + assertThat(localStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + Map keyFrequency = (Map) localStatistics.result(); + assertThat(keyFrequency) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 2L, CHAR_KEYS.get("b"), 1L)); + } else { + ReservoirItemsSketch sketch = + (ReservoirItemsSketch) localStatistics.result(); + assertThat(sketch.getSamples()) + .containsExactly(CHAR_KEYS.get("a"), CHAR_KEYS.get("a"), CHAR_KEYS.get("b")); + } testHarness.endInput(); } } - @Test - public void testOperatorOutput() throws Exception { - try (OneInputStreamOperatorTestHarness< - RowData, DataStatisticsOrRecord>> - testHarness = createHarness(this.operator)) { + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void testOperatorOutput(StatisticsType type) throws Exception { + DataStatisticsOperator operator = createOperator(type, Fixtures.NUM_SUBTASKS); + try (OneInputStreamOperatorTestHarness testHarness = + createHarness(operator)) { testHarness.processElement( new 
StreamRecord<>(GenericRowData.of(StringData.fromString("a"), 2))); testHarness.processElement( @@ -154,8 +150,8 @@ public void testOperatorOutput() throws Exception { List recordsOutput = testHarness.extractOutputValues().stream() - .filter(DataStatisticsOrRecord::hasRecord) - .map(DataStatisticsOrRecord::record) + .filter(StatisticsOrRecord::hasRecord) + .map(StatisticsOrRecord::record) .collect(Collectors.toList()); assertThat(recordsOutput) .containsExactlyInAnyOrderElementsOf( @@ -166,70 +162,172 @@ public void testOperatorOutput() throws Exception { } } - @Test - public void testRestoreState() throws Exception { + private static Stream provideRestoreStateParameters() { + return Stream.of( + Arguments.of(StatisticsType.Map, -1), + Arguments.of(StatisticsType.Map, 0), + Arguments.of(StatisticsType.Map, 1), + Arguments.of(StatisticsType.Sketch, -1), + Arguments.of(StatisticsType.Sketch, 0), + Arguments.of(StatisticsType.Sketch, 1)); + } + + @ParameterizedTest + @MethodSource("provideRestoreStateParameters") + public void testRestoreState(StatisticsType type, int parallelismAdjustment) throws Exception { + Map keyFrequency = + ImmutableMap.of(CHAR_KEYS.get("a"), 2L, CHAR_KEYS.get("b"), 1L, CHAR_KEYS.get("c"), 1L); + SortKey[] rangeBounds = new SortKey[] {CHAR_KEYS.get("a")}; + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(2, keyFrequency, 0.0d, SORT_ORDER_COMPARTOR); + DataStatisticsOperator operator = createOperator(type, Fixtures.NUM_SUBTASKS); OperatorSubtaskState snapshot; - try (OneInputStreamOperatorTestHarness< - RowData, DataStatisticsOrRecord>> - testHarness1 = createHarness(this.operator)) { - MapDataStatistics mapDataStatistics = new MapDataStatistics(); - - SortKey key = sortKey.copy(); - key.set(0, "a"); - mapDataStatistics.add(key); - key.set(0, "a"); - mapDataStatistics.add(key); - key.set(0, "b"); - mapDataStatistics.add(key); - key.set(0, "c"); - mapDataStatistics.add(key); - - SortKey keyA = sortKey.copy(); - keyA.set(0, "a"); - SortKey keyB = sortKey.copy(); - keyB.set(0, "b"); - SortKey keyC = sortKey.copy(); - keyC.set(0, "c"); - Map expectedMap = ImmutableMap.of(keyA, 2L, keyB, 1L, keyC, 1L); - - DataStatisticsEvent> event = - DataStatisticsEvent.create(0, mapDataStatistics, statisticsSerializer); + try (OneInputStreamOperatorTestHarness testHarness1 = + createHarness(operator)) { + GlobalStatistics statistics; + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + statistics = GlobalStatistics.fromMapAssignment(1L, mapAssignment); + } else { + statistics = GlobalStatistics.fromRangeBounds(1L, rangeBounds); + } + + StatisticsEvent event = + StatisticsEvent.createGlobalStatisticsEvent( + statistics, Fixtures.GLOBAL_STATISTICS_SERIALIZER, false); operator.handleOperatorEvent(event); - assertThat(operator.globalDataStatistics()).isInstanceOf(MapDataStatistics.class); - assertThat(operator.globalDataStatistics().statistics()) - .containsExactlyInAnyOrderEntriesOf(expectedMap); + + GlobalStatistics globalStatistics = operator.globalStatistics(); + assertThat(globalStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(globalStatistics.mapAssignment()).isEqualTo(mapAssignment); + assertThat(globalStatistics.rangeBounds()).isNull(); + } else { + assertThat(globalStatistics.mapAssignment()).isNull(); + assertThat(globalStatistics.rangeBounds()).isEqualTo(rangeBounds); + } + snapshot = testHarness1.snapshot(1L, 0); } // Use the snapshot to initialize state for 
another new operator and then verify that the global // statistics for the new operator is same as before - DataStatisticsOperator> restoredOperator = - createOperator(); - try (OneInputStreamOperatorTestHarness< - RowData, DataStatisticsOrRecord>> - testHarness2 = new OneInputStreamOperatorTestHarness<>(restoredOperator, 2, 2, 1)) { + MockOperatorEventGateway spyGateway = Mockito.spy(new MockOperatorEventGateway()); + DataStatisticsOperator restoredOperator = + createOperator(type, Fixtures.NUM_SUBTASKS + parallelismAdjustment, spyGateway); + try (OneInputStreamOperatorTestHarness testHarness2 = + new OneInputStreamOperatorTestHarness<>(restoredOperator, 2, 2, 1)) { testHarness2.setup(); testHarness2.initializeState(snapshot); - assertThat(restoredOperator.globalDataStatistics()).isInstanceOf(MapDataStatistics.class); - // restored RowData is BinaryRowData. convert to GenericRowData for comparison - Map restoredStatistics = Maps.newHashMap(); - restoredStatistics.putAll(restoredOperator.globalDataStatistics().statistics()); + GlobalStatistics globalStatistics = restoredOperator.globalStatistics(); + // global statistics is always restored and used initially even if + // downstream parallelism changed. + assertThat(globalStatistics).isNotNull(); + // request is always sent to coordinator during initialization. + // coordinator would respond with a new global statistics that + // has range bound recomputed with new parallelism. + verify(spyGateway).sendEventToCoordinator(any(RequestGlobalStatisticsEvent.class)); + assertThat(globalStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(globalStatistics.mapAssignment()).isEqualTo(mapAssignment); + assertThat(globalStatistics.rangeBounds()).isNull(); + } else { + assertThat(globalStatistics.mapAssignment()).isNull(); + assertThat(globalStatistics.rangeBounds()).isEqualTo(rangeBounds); + } + } + } - SortKey keyA = sortKey.copy(); - keyA.set(0, "a"); - SortKey keyB = sortKey.copy(); - keyB.set(0, "b"); - SortKey keyC = sortKey.copy(); - keyC.set(0, "c"); - Map expectedMap = ImmutableMap.of(keyA, 2L, keyB, 1L, keyC, 1L); + @SuppressWarnings("unchecked") + @Test + public void testMigrationWithLocalStatsOverThreshold() throws Exception { + DataStatisticsOperator operator = createOperator(StatisticsType.Auto, Fixtures.NUM_SUBTASKS); + try (OneInputStreamOperatorTestHarness testHarness = + createHarness(operator)) { + StateInitializationContext stateContext = getStateContext(); + operator.initializeState(stateContext); + + // add rows with unique keys + for (int i = 0; i < SketchUtil.OPERATOR_SKETCH_SWITCH_THRESHOLD; ++i) { + operator.processElement( + new StreamRecord<>(GenericRowData.of(StringData.fromString(String.valueOf(i)), i))); + assertThat(operator.localStatistics().type()).isEqualTo(StatisticsType.Map); + assertThat((Map) operator.localStatistics().result()).hasSize(i + 1); + } + + // one more item should trigger the migration to sketch stats + operator.processElement( + new StreamRecord<>(GenericRowData.of(StringData.fromString("key-trigger-migration"), 1))); + + int reservoirSize = + SketchUtil.determineOperatorReservoirSize(Fixtures.NUM_SUBTASKS, Fixtures.NUM_SUBTASKS); + + assertThat(operator.localStatistics().type()).isEqualTo(StatisticsType.Sketch); + ReservoirItemsSketch sketch = + (ReservoirItemsSketch) operator.localStatistics().result(); + assertThat(sketch.getK()).isEqualTo(reservoirSize); + 
assertThat(sketch.getN()).isEqualTo(SketchUtil.OPERATOR_SKETCH_SWITCH_THRESHOLD + 1); + // reservoir not full yet + assertThat(sketch.getN()).isLessThan(reservoirSize); + assertThat(sketch.getSamples()).hasSize((int) sketch.getN()); + + // add more items to saturate the reservoir + for (int i = 0; i < reservoirSize; ++i) { + operator.processElement( + new StreamRecord<>(GenericRowData.of(StringData.fromString(String.valueOf(i)), i))); + } + + assertThat(operator.localStatistics().type()).isEqualTo(StatisticsType.Sketch); + sketch = (ReservoirItemsSketch) operator.localStatistics().result(); + assertThat(sketch.getK()).isEqualTo(reservoirSize); + assertThat(sketch.getN()) + .isEqualTo(SketchUtil.OPERATOR_SKETCH_SWITCH_THRESHOLD + 1 + reservoirSize); + // reservoir is full now + assertThat(sketch.getN()).isGreaterThan(reservoirSize); + assertThat(sketch.getSamples()).hasSize(reservoirSize); + + testHarness.endInput(); + } + } + + @SuppressWarnings("unchecked") + @Test + public void testMigrationWithGlobalSketchStatistics() throws Exception { + DataStatisticsOperator operator = createOperator(StatisticsType.Auto, Fixtures.NUM_SUBTASKS); + try (OneInputStreamOperatorTestHarness testHarness = + createHarness(operator)) { + StateInitializationContext stateContext = getStateContext(); + operator.initializeState(stateContext); - assertThat(restoredStatistics).containsExactlyInAnyOrderEntriesOf(expectedMap); + // started with Map stype + operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("a"), 1))); + assertThat(operator.localStatistics().type()).isEqualTo(StatisticsType.Map); + assertThat((Map) operator.localStatistics().result()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L)); + + // received global statistics with sketch type + GlobalStatistics globalStatistics = + GlobalStatistics.fromRangeBounds( + 1L, new SortKey[] {CHAR_KEYS.get("c"), CHAR_KEYS.get("f")}); + operator.handleOperatorEvent( + StatisticsEvent.createGlobalStatisticsEvent( + globalStatistics, Fixtures.GLOBAL_STATISTICS_SERIALIZER, false)); + + int reservoirSize = + SketchUtil.determineOperatorReservoirSize(Fixtures.NUM_SUBTASKS, Fixtures.NUM_SUBTASKS); + + assertThat(operator.localStatistics().type()).isEqualTo(StatisticsType.Sketch); + ReservoirItemsSketch sketch = + (ReservoirItemsSketch) operator.localStatistics().result(); + assertThat(sketch.getK()).isEqualTo(reservoirSize); + assertThat(sketch.getN()).isEqualTo(1); + assertThat(sketch.getSamples()).isEqualTo(new SortKey[] {CHAR_KEYS.get("a")}); + + testHarness.endInput(); } } private StateInitializationContext getStateContext() throws Exception { - MockEnvironment env = new MockEnvironmentBuilder().build(); AbstractStateBackend abstractStateBackend = new HashMapStateBackend(); CloseableRegistry cancelStreamRegistry = new CloseableRegistry(); OperatorStateStore operatorStateStore = @@ -238,17 +336,14 @@ private StateInitializationContext getStateContext() throws Exception { return new StateInitializationContextImpl(null, operatorStateStore, null, null, null); } - private OneInputStreamOperatorTestHarness< - RowData, DataStatisticsOrRecord>> - createHarness( - final DataStatisticsOperator> - dataStatisticsOperator) - throws Exception { - - OneInputStreamOperatorTestHarness< - RowData, DataStatisticsOrRecord>> - harness = new OneInputStreamOperatorTestHarness<>(dataStatisticsOperator, 1, 1, 0); - harness.setup(new DataStatisticsOrRecordSerializer<>(statisticsSerializer, rowSerializer)); + private OneInputStreamOperatorTestHarness 
createHarness( + DataStatisticsOperator dataStatisticsOperator) throws Exception { + OneInputStreamOperatorTestHarness harness = + new OneInputStreamOperatorTestHarness<>( + dataStatisticsOperator, Fixtures.NUM_SUBTASKS, Fixtures.NUM_SUBTASKS, 0); + harness.setup( + new StatisticsOrRecordSerializer( + Fixtures.GLOBAL_STATISTICS_SERIALIZER, Fixtures.ROW_SERIALIZER)); harness.open(); return harness; } diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsSerializer.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsSerializer.java new file mode 100644 index 000000000000..59ce6df05d9d --- /dev/null +++ b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsSerializer.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; + +import org.apache.flink.api.common.typeutils.SerializerTestBase; +import org.apache.flink.api.common.typeutils.TypeSerializer; + +public class TestDataStatisticsSerializer extends SerializerTestBase { + @Override + protected TypeSerializer createSerializer() { + return Fixtures.TASK_STATISTICS_SERIALIZER; + } + + @Override + protected int getLength() { + return -1; + } + + @Override + protected Class getTypeClass() { + return DataStatistics.class; + } + + @Override + protected DataStatistics[] getTestData() { + return new DataStatistics[] { + new MapDataStatistics(), + Fixtures.createTaskStatistics( + StatisticsType.Map, CHAR_KEYS.get("a"), CHAR_KEYS.get("a"), CHAR_KEYS.get("b")), + new SketchDataStatistics(128), + Fixtures.createTaskStatistics( + StatisticsType.Sketch, CHAR_KEYS.get("a"), CHAR_KEYS.get("a"), CHAR_KEYS.get("b")) + }; + } +} diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestGlobalStatisticsSerializer.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestGlobalStatisticsSerializer.java new file mode 100644 index 000000000000..7afaf239c668 --- /dev/null +++ b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestGlobalStatisticsSerializer.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.SORT_ORDER_COMPARTOR; + +import org.apache.flink.api.common.typeutils.SerializerTestBase; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; + +public class TestGlobalStatisticsSerializer extends SerializerTestBase { + + @Override + protected TypeSerializer createSerializer() { + return Fixtures.GLOBAL_STATISTICS_SERIALIZER; + } + + @Override + protected int getLength() { + return -1; + } + + @Override + protected Class getTypeClass() { + return GlobalStatistics.class; + } + + @Override + protected GlobalStatistics[] getTestData() { + return new GlobalStatistics[] { + GlobalStatistics.fromMapAssignment( + 1L, + MapAssignment.fromKeyFrequency( + Fixtures.NUM_SUBTASKS, + ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 2L), + 0.0d, + SORT_ORDER_COMPARTOR)), + GlobalStatistics.fromRangeBounds(2L, new SortKey[] {CHAR_KEYS.get("a"), CHAR_KEYS.get("b")}) + }; + } +} diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapDataStatistics.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapDataStatistics.java index be2beeebc93c..8a25c7ad9898 100644 --- a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapDataStatistics.java +++ b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapDataStatistics.java @@ -18,74 +18,50 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.ROW_WRAPPER; import static org.assertj.core.api.Assertions.assertThat; import java.util.Map; import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.StringData; -import org.apache.flink.table.types.logical.RowType; import org.apache.iceberg.SortKey; -import org.apache.iceberg.SortOrder; -import org.apache.iceberg.flink.FlinkSchemaUtil; -import org.apache.iceberg.flink.RowDataWrapper; -import org.apache.iceberg.flink.TestFixtures; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.Test; public class TestMapDataStatistics { - private final SortOrder sortOrder = SortOrder.builderFor(TestFixtures.SCHEMA).asc("data").build(); - private final SortKey sortKey = new SortKey(TestFixtures.SCHEMA, sortOrder); - private final RowType rowType = FlinkSchemaUtil.convert(TestFixtures.SCHEMA); - private final RowDataWrapper rowWrapper = - new RowDataWrapper(rowType, TestFixtures.SCHEMA.asStruct()); - + @SuppressWarnings("unchecked") @Test public void testAddsAndGet() { MapDataStatistics dataStatistics = new MapDataStatistics(); - GenericRowData reusedRow = - GenericRowData.of(StringData.fromString("a"), 1, StringData.fromString("2023-06-20")); - 
sortKey.wrap(rowWrapper.wrap(reusedRow)); - dataStatistics.add(sortKey); + GenericRowData reusedRow = GenericRowData.of(StringData.fromString("a"), 1); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); reusedRow.setField(0, StringData.fromString("b")); - sortKey.wrap(rowWrapper.wrap(reusedRow)); - dataStatistics.add(sortKey); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); reusedRow.setField(0, StringData.fromString("c")); - sortKey.wrap(rowWrapper.wrap(reusedRow)); - dataStatistics.add(sortKey); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); reusedRow.setField(0, StringData.fromString("b")); - sortKey.wrap(rowWrapper.wrap(reusedRow)); - dataStatistics.add(sortKey); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); reusedRow.setField(0, StringData.fromString("a")); - sortKey.wrap(rowWrapper.wrap(reusedRow)); - dataStatistics.add(sortKey); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); reusedRow.setField(0, StringData.fromString("b")); - sortKey.wrap(rowWrapper.wrap(reusedRow)); - dataStatistics.add(sortKey); - - Map actual = dataStatistics.statistics(); - - rowWrapper.wrap( - GenericRowData.of(StringData.fromString("a"), 1, StringData.fromString("2023-06-20"))); - sortKey.wrap(rowWrapper); - SortKey keyA = sortKey.copy(); - - rowWrapper.wrap( - GenericRowData.of(StringData.fromString("b"), 1, StringData.fromString("2023-06-20"))); - sortKey.wrap(rowWrapper); - SortKey keyB = sortKey.copy(); - - rowWrapper.wrap( - GenericRowData.of(StringData.fromString("c"), 1, StringData.fromString("2023-06-20"))); - sortKey.wrap(rowWrapper); - SortKey keyC = sortKey.copy(); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); - Map expected = ImmutableMap.of(keyA, 2L, keyB, 3L, keyC, 1L); + Map actual = (Map) dataStatistics.result(); + Map expected = + ImmutableMap.of(CHAR_KEYS.get("a"), 2L, CHAR_KEYS.get("b"), 3L, CHAR_KEYS.get("c"), 1L); assertThat(actual).isEqualTo(expected); } } diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java index e6726e7db785..d5a0bebc74e7 100644 --- a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java +++ b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java @@ -18,6 +18,7 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.SORT_ORDER_COMPARTOR; import static org.assertj.core.api.Assertions.assertThat; import java.util.List; @@ -64,65 +65,60 @@ private static SortKey[] initSortKeys() { } // Total weight is 800 - private final MapDataStatistics mapDataStatistics = - new MapDataStatistics( - ImmutableMap.of( - SORT_KEYS[0], - 350L, - SORT_KEYS[1], - 230L, - SORT_KEYS[2], - 120L, - SORT_KEYS[3], - 40L, - SORT_KEYS[4], - 10L, - SORT_KEYS[5], - 10L, - SORT_KEYS[6], - 10L, - SORT_KEYS[7], - 10L, - SORT_KEYS[8], - 10L, - SORT_KEYS[9], - 10L)); + private final Map mapStatistics = + ImmutableMap.of( + SORT_KEYS[0], + 350L, + SORT_KEYS[1], + 230L, + SORT_KEYS[2], + 120L, + SORT_KEYS[3], + 40L, + SORT_KEYS[4], + 10L, + SORT_KEYS[5], + 10L, + SORT_KEYS[6], + 10L, + SORT_KEYS[7], + 10L, + 
SORT_KEYS[8], + 10L, + SORT_KEYS[9], + 10L); @Test public void testEvenlyDividableNoClosingFileCost() { - MapRangePartitioner partitioner = - new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapDataStatistics, 0.0); int numPartitions = 8; + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(numPartitions, mapStatistics, 0.0, SORT_ORDER_COMPARTOR); // each task should get targeted weight of 100 (=800/8) - Map expectedAssignment = + Map expectedAssignment = ImmutableMap.of( SORT_KEYS[0], - new MapRangePartitioner.KeyAssignment( + new KeyAssignment( ImmutableList.of(0, 1, 2, 3), ImmutableList.of(100L, 100L, 100L, 50L), 0L), SORT_KEYS[1], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(3, 4, 5), ImmutableList.of(50L, 100L, 80L), 0L), + new KeyAssignment(ImmutableList.of(3, 4, 5), ImmutableList.of(50L, 100L, 80L), 0L), SORT_KEYS[2], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(5, 6), ImmutableList.of(20L, 100L), 0L), + new KeyAssignment(ImmutableList.of(5, 6), ImmutableList.of(20L, 100L), 0L), SORT_KEYS[3], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(40L), 0L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(40L), 0L), SORT_KEYS[4], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), SORT_KEYS[5], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), SORT_KEYS[6], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), SORT_KEYS[7], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), SORT_KEYS[8], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), SORT_KEYS[9], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L)); - Map actualAssignment = - partitioner.assignment(numPartitions); - assertThat(actualAssignment).isEqualTo(expectedAssignment); + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L)); + assertThat(mapAssignment).isEqualTo(new MapAssignment(numPartitions, expectedAssignment)); // key: subtask id // value pair: first is the assigned weight, second is the number of assigned keys @@ -144,19 +140,20 @@ public void testEvenlyDividableNoClosingFileCost() { Pair.of(100L, 1), 7, Pair.of(100L, 7)); - Map> actualAssignmentInfo = partitioner.assignmentInfo(); - assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo); + assertThat(mapAssignment.assignmentInfo()).isEqualTo(expectedAssignmentInfo); + MapRangePartitioner partitioner = + new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapAssignment); Map>> partitionResults = - runPartitioner(partitioner, numPartitions); + runPartitioner(partitioner, numPartitions, mapStatistics); validatePartitionResults(expectedAssignmentInfo, partitionResults, 5.0); } @Test public void testEvenlyDividableWithClosingFileCost() { - MapRangePartitioner partitioner = - new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapDataStatistics, 5.0); int numPartitions = 8; + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(numPartitions, mapStatistics, 5.0, SORT_ORDER_COMPARTOR); // 
target subtask weight is 100 before close file cost factored in. // close file cost is 5 = 5% * 100. @@ -165,35 +162,30 @@ public void testEvenlyDividableWithClosingFileCost() { // close-cost: 20, 15, 10, 5, 5, 5, 5, 5, 5, 5 // after: 370, 245, 130, 45, 15, 15, 15, 15, 15, 15 // target subtask weight with close cost per subtask is 110 (880/8) - Map expectedAssignment = + Map expectedAssignment = ImmutableMap.of( SORT_KEYS[0], - new MapRangePartitioner.KeyAssignment( + new KeyAssignment( ImmutableList.of(0, 1, 2, 3), ImmutableList.of(110L, 110L, 110L, 40L), 5L), SORT_KEYS[1], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(3, 4, 5), ImmutableList.of(70L, 110L, 65L), 5L), + new KeyAssignment(ImmutableList.of(3, 4, 5), ImmutableList.of(70L, 110L, 65L), 5L), SORT_KEYS[2], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(5, 6), ImmutableList.of(45L, 85L), 5L), + new KeyAssignment(ImmutableList.of(5, 6), ImmutableList.of(45L, 85L), 5L), SORT_KEYS[3], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(6, 7), ImmutableList.of(25L, 20L), 5L), + new KeyAssignment(ImmutableList.of(6, 7), ImmutableList.of(25L, 20L), 5L), SORT_KEYS[4], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), SORT_KEYS[5], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), SORT_KEYS[6], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), SORT_KEYS[7], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), SORT_KEYS[8], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), SORT_KEYS[9], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L)); - Map actualAssignment = - partitioner.assignment(numPartitions); - assertThat(actualAssignment).isEqualTo(expectedAssignment); + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L)); + assertThat(mapAssignment.keyAssignments()).isEqualTo(expectedAssignment); // key: subtask id // value pair: first is the assigned weight (excluding close file cost) for the subtask, @@ -216,51 +208,48 @@ public void testEvenlyDividableWithClosingFileCost() { Pair.of(100L, 2), 7, Pair.of(75L, 7)); - Map> actualAssignmentInfo = partitioner.assignmentInfo(); - assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo); + assertThat(mapAssignment.assignmentInfo()).isEqualTo(expectedAssignmentInfo); + MapRangePartitioner partitioner = + new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapAssignment); Map>> partitionResults = - runPartitioner(partitioner, numPartitions); + runPartitioner(partitioner, numPartitions, mapStatistics); validatePartitionResults(expectedAssignmentInfo, partitionResults, 5.0); } @Test public void testNonDividableNoClosingFileCost() { - MapRangePartitioner partitioner = - new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapDataStatistics, 0.0); int numPartitions = 9; + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(numPartitions, mapStatistics, 0.0, SORT_ORDER_COMPARTOR); // before: 350, 230, 120, 40, 10, 10, 10, 10, 10, 10 // each task should get 
targeted weight of 89 = ceiling(800/9) - Map expectedAssignment = + Map expectedAssignment = ImmutableMap.of( SORT_KEYS[0], - new MapRangePartitioner.KeyAssignment( + new KeyAssignment( ImmutableList.of(0, 1, 2, 3), ImmutableList.of(89L, 89L, 89L, 83L), 0L), SORT_KEYS[1], - new MapRangePartitioner.KeyAssignment( + new KeyAssignment( ImmutableList.of(3, 4, 5, 6), ImmutableList.of(6L, 89L, 89L, 46L), 0L), SORT_KEYS[2], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(6, 7), ImmutableList.of(43L, 77L), 0L), + new KeyAssignment(ImmutableList.of(6, 7), ImmutableList.of(43L, 77L), 0L), SORT_KEYS[3], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(7, 8), ImmutableList.of(12L, 28L), 0L), + new KeyAssignment(ImmutableList.of(7, 8), ImmutableList.of(12L, 28L), 0L), SORT_KEYS[4], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), SORT_KEYS[5], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), SORT_KEYS[6], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), SORT_KEYS[7], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), SORT_KEYS[8], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), SORT_KEYS[9], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L)); - Map actualAssignment = - partitioner.assignment(numPartitions); - assertThat(actualAssignment).isEqualTo(expectedAssignment); + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L)); + assertThat(mapAssignment.keyAssignments()).isEqualTo(expectedAssignment); // key: subtask id // value pair: first is the assigned weight, second is the number of assigned keys @@ -284,19 +273,20 @@ public void testNonDividableNoClosingFileCost() { Pair.of(89L, 2), 8, Pair.of(88L, 7)); - Map> actualAssignmentInfo = partitioner.assignmentInfo(); - assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo); + assertThat(mapAssignment.assignmentInfo()).isEqualTo(expectedAssignmentInfo); + MapRangePartitioner partitioner = + new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapAssignment); Map>> partitionResults = - runPartitioner(partitioner, numPartitions); + runPartitioner(partitioner, numPartitions, mapStatistics); validatePartitionResults(expectedAssignmentInfo, partitionResults, 5.0); } @Test public void testNonDividableWithClosingFileCost() { - MapRangePartitioner partitioner = - new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapDataStatistics, 5.0); int numPartitions = 9; + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(numPartitions, mapStatistics, 5.0, SORT_ORDER_COMPARTOR); // target subtask weight is 89 before close file cost factored in. // close file cost is 5 (= 5% * 89) per file. 
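For intuition about the expected KeyAssignment values in these partitioner tests: they follow from greedily filling subtasks up to a per-subtask target weight. The standalone sketch below, with made-up names (WeightSplitSketch.split) and the close-file cost ignored, reproduces the splits for the first few keys; it is an illustration of the idea, not the MapAssignment implementation.

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class WeightSplitSketch {
  // Returns, per key, the "subtask -> assigned weight" slices, filling subtasks in order.
  static Map<String, List<String>> split(LinkedHashMap<String, Long> keyWeights, long targetWeight) {
    Map<String, List<String>> assignments = new LinkedHashMap<>();
    int subtask = 0;
    long remainingCapacity = targetWeight;
    for (Map.Entry<String, Long> entry : keyWeights.entrySet()) {
      long remainingWeight = entry.getValue();
      List<String> slices = new ArrayList<>();
      while (remainingWeight > 0) {
        long slice = Math.min(remainingWeight, remainingCapacity);
        slices.add("subtask " + subtask + " -> " + slice);
        remainingWeight -= slice;
        remainingCapacity -= slice;
        if (remainingCapacity == 0) {
          subtask++;                       // current subtask is full, continue on the next one
          remainingCapacity = targetWeight;
        }
      }
      assignments.put(entry.getKey(), slices);
    }
    return assignments;
  }

  public static void main(String[] args) {
    LinkedHashMap<String, Long> weights = new LinkedHashMap<>();
    weights.put("k0", 350L);
    weights.put("k1", 230L);
    weights.put("k2", 120L);
    // With target 100 (the 800 / 8 dividable, no-close-cost case) this prints
    // k0=[subtask 0 -> 100, subtask 1 -> 100, subtask 2 -> 100, subtask 3 -> 50],
    // k1=[subtask 3 -> 50, subtask 4 -> 100, subtask 5 -> 80],
    // k2=[subtask 5 -> 20, subtask 6 -> 100]
    System.out.println(split(weights, 100L));
  }
}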
@@ -305,35 +295,31 @@ public void testNonDividableWithClosingFileCost() { // close-cost: 20, 15, 10, 5, 5, 5, 5, 5, 5, 5 // after: 370, 245, 130, 45, 15, 15, 15, 15, 15, 15 // target subtask weight per subtask is 98 ceiling(880/9) - Map expectedAssignment = + Map expectedAssignment = ImmutableMap.of( SORT_KEYS[0], - new MapRangePartitioner.KeyAssignment( + new KeyAssignment( ImmutableList.of(0, 1, 2, 3), ImmutableList.of(98L, 98L, 98L, 76L), 5L), SORT_KEYS[1], - new MapRangePartitioner.KeyAssignment( + new KeyAssignment( ImmutableList.of(3, 4, 5, 6), ImmutableList.of(22L, 98L, 98L, 27L), 5L), SORT_KEYS[2], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(6, 7), ImmutableList.of(71L, 59L), 5L), + new KeyAssignment(ImmutableList.of(6, 7), ImmutableList.of(71L, 59L), 5L), SORT_KEYS[3], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(7, 8), ImmutableList.of(39L, 6L), 5L), + new KeyAssignment(ImmutableList.of(7, 8), ImmutableList.of(39L, 6L), 5L), SORT_KEYS[4], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), SORT_KEYS[5], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), SORT_KEYS[6], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), SORT_KEYS[7], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), SORT_KEYS[8], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), SORT_KEYS[9], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L)); - Map actualAssignment = - partitioner.assignment(numPartitions); - assertThat(actualAssignment).isEqualTo(expectedAssignment); + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L)); + assertThat(mapAssignment.keyAssignments()).isEqualTo(expectedAssignment); // key: subtask id // value pair: first is the assigned weight for the subtask, second is the number of keys @@ -358,40 +344,39 @@ public void testNonDividableWithClosingFileCost() { Pair.of(88L, 2), 8, Pair.of(61L, 7)); - Map> actualAssignmentInfo = partitioner.assignmentInfo(); - assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo); + assertThat(mapAssignment.assignmentInfo()).isEqualTo(expectedAssignmentInfo); + MapRangePartitioner partitioner = + new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapAssignment); Map>> partitionResults = - runPartitioner(partitioner, numPartitions); + runPartitioner(partitioner, numPartitions, mapStatistics); // drift threshold is high for non-dividable scenario with close cost validatePartitionResults(expectedAssignmentInfo, partitionResults, 10.0); } private static Map>> runPartitioner( - MapRangePartitioner partitioner, int numPartitions) { + MapRangePartitioner partitioner, int numPartitions, Map mapStatistics) { // The Map key is the subtaskId. // For the map value pair, the first element is the count of assigned and // the second element of Set is for the set of assigned keys. 
Map>> partitionResults = Maps.newHashMap(); - partitioner - .mapStatistics() - .forEach( - (sortKey, weight) -> { - String key = sortKey.get(0, String.class); - // run 100x times of the weight - long iterations = weight * 100; - for (int i = 0; i < iterations; ++i) { - RowData rowData = - GenericRowData.of( - StringData.fromString(key), 1, StringData.fromString("2023-06-20")); - int subtaskId = partitioner.partition(rowData, numPartitions); - partitionResults.computeIfAbsent( - subtaskId, k -> Pair.of(new AtomicLong(0), Sets.newHashSet())); - Pair> pair = partitionResults.get(subtaskId); - pair.first().incrementAndGet(); - pair.second().add(rowData); - } - }); + mapStatistics.forEach( + (sortKey, weight) -> { + String key = sortKey.get(0, String.class); + // run 100x times of the weight + long iterations = weight * 100; + for (int i = 0; i < iterations; ++i) { + RowData rowData = + GenericRowData.of( + StringData.fromString(key), 1, StringData.fromString("2023-06-20")); + int subtaskId = partitioner.partition(rowData, numPartitions); + partitionResults.computeIfAbsent( + subtaskId, k -> Pair.of(new AtomicLong(0), Sets.newHashSet())); + Pair> pair = partitionResults.get(subtaskId); + pair.first().incrementAndGet(); + pair.second().add(rowData); + } + }); return partitionResults; } diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSketchDataStatistics.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSketchDataStatistics.java new file mode 100644 index 000000000000..396bfae2f13c --- /dev/null +++ b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSketchDataStatistics.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.ROW_WRAPPER; +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.datasketches.sampling.ReservoirItemsSketch; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.StringData; +import org.apache.iceberg.SortKey; +import org.junit.jupiter.api.Test; + +public class TestSketchDataStatistics { + @SuppressWarnings("unchecked") + @Test + public void testAddsAndGet() { + SketchDataStatistics dataStatistics = new SketchDataStatistics(128); + + GenericRowData reusedRow = GenericRowData.of(StringData.fromString("a"), 1); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); + + reusedRow.setField(0, StringData.fromString("b")); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); + + reusedRow.setField(0, StringData.fromString("c")); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); + + reusedRow.setField(0, StringData.fromString("b")); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); + + ReservoirItemsSketch actual = (ReservoirItemsSketch) dataStatistics.result(); + assertThat(actual.getSamples()) + .isEqualTo( + new SortKey[] { + CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("c"), CHAR_KEYS.get("b") + }); + } +} diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSketchUtil.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSketchUtil.java new file mode 100644 index 000000000000..31dae5c76aeb --- /dev/null +++ b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSketchUtil.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.iceberg.SortKey; +import org.junit.jupiter.api.Test; + +public class TestSketchUtil { + @Test + public void testCoordinatorReservoirSize() { + // adjusted to over min threshold of 10_000 and is divisible by number of partitions (3) + assertThat(SketchUtil.determineCoordinatorReservoirSize(3)).isEqualTo(10_002); + // adjust to multiplier of 100 + assertThat(SketchUtil.determineCoordinatorReservoirSize(123)).isEqualTo(123_00); + // adjusted to below max threshold of 1_000_000 and is divisible by number of partitions (3) + assertThat(SketchUtil.determineCoordinatorReservoirSize(10_123)) + .isEqualTo(1_000_000 - (1_000_000 % 10_123)); + } + + @Test + public void testOperatorReservoirSize() { + assertThat(SketchUtil.determineOperatorReservoirSize(5, 3)) + .isEqualTo((10_002 * SketchUtil.OPERATOR_OVER_SAMPLE_RATIO) / 5); + assertThat(SketchUtil.determineOperatorReservoirSize(123, 123)) + .isEqualTo((123_00 * SketchUtil.OPERATOR_OVER_SAMPLE_RATIO) / 123); + assertThat(SketchUtil.determineOperatorReservoirSize(256, 123)) + .isEqualTo( + (int) Math.ceil((double) (123_00 * SketchUtil.OPERATOR_OVER_SAMPLE_RATIO) / 256)); + assertThat(SketchUtil.determineOperatorReservoirSize(5_120, 10_123)) + .isEqualTo( + (int) Math.ceil((double) (992_054 * SketchUtil.OPERATOR_OVER_SAMPLE_RATIO) / 5_120)); + } + + @Test + public void testRangeBoundsOneChannel() { + assertThat( + SketchUtil.rangeBounds( + 1, + Fixtures.SORT_ORDER_COMPARTOR, + new SortKey[] { + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d"), + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f") + })) + .isEmpty(); + } + + @Test + public void testRangeBoundsDivisible() { + assertThat( + SketchUtil.rangeBounds( + 3, + Fixtures.SORT_ORDER_COMPARTOR, + new SortKey[] { + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d"), + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f") + })) + .containsExactly(CHAR_KEYS.get("b"), CHAR_KEYS.get("d")); + } + + @Test + public void testRangeBoundsNonDivisible() { + // step is 3 = ceiling(11/4) + assertThat( + SketchUtil.rangeBounds( + 4, + Fixtures.SORT_ORDER_COMPARTOR, + new SortKey[] { + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d"), + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f"), + CHAR_KEYS.get("g"), + CHAR_KEYS.get("h"), + CHAR_KEYS.get("i"), + CHAR_KEYS.get("j"), + CHAR_KEYS.get("k"), + })) + .containsExactly(CHAR_KEYS.get("c"), CHAR_KEYS.get("f"), CHAR_KEYS.get("i")); + } + + @Test + public void testRangeBoundsSkipDuplicates() { + // step is 3 = ceiling(11/4) + assertThat( + SketchUtil.rangeBounds( + 4, + Fixtures.SORT_ORDER_COMPARTOR, + new SortKey[] { + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("g"), + CHAR_KEYS.get("h"), + CHAR_KEYS.get("i"), + CHAR_KEYS.get("j"), + CHAR_KEYS.get("k"), + })) + // skipped duplicate c's + .containsExactly(CHAR_KEYS.get("c"), CHAR_KEYS.get("g"), CHAR_KEYS.get("j")); + } +} diff --git a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerPrimitives.java b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerPrimitives.java index 291302aef486..54cceae6e55b 100644 --- 
a/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerPrimitives.java +++ b/flink/v1.17/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerPrimitives.java @@ -18,14 +18,24 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import org.apache.flink.core.memory.DataInputDeserializer; +import org.apache.flink.core.memory.DataOutputSerializer; import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; import org.apache.iceberg.NullOrder; import org.apache.iceberg.Schema; import org.apache.iceberg.SortDirection; +import org.apache.iceberg.SortKey; import org.apache.iceberg.SortOrder; +import org.apache.iceberg.StructLike; import org.apache.iceberg.expressions.Expressions; import org.apache.iceberg.flink.DataGenerator; import org.apache.iceberg.flink.DataGenerators; +import org.apache.iceberg.flink.RowDataWrapper; +import org.junit.jupiter.api.Test; public class TestSortKeySerializerPrimitives extends TestSortKeySerializerBase { private final DataGenerator generator = new DataGenerators.Primitives(); @@ -54,4 +64,27 @@ protected SortOrder sortOrder() { protected GenericRowData rowData() { return generator.generateFlinkRowData(); } + + @Test + public void testSerializationSize() throws Exception { + RowData rowData = + GenericRowData.of(StringData.fromString("550e8400-e29b-41d4-a716-446655440000"), 1L); + RowDataWrapper rowDataWrapper = + new RowDataWrapper(Fixtures.ROW_TYPE, Fixtures.SCHEMA.asStruct()); + StructLike struct = rowDataWrapper.wrap(rowData); + SortKey sortKey = Fixtures.SORT_KEY.copy(); + sortKey.wrap(struct); + SortKeySerializer serializer = new SortKeySerializer(Fixtures.SCHEMA, Fixtures.SORT_ORDER); + DataOutputSerializer output = new DataOutputSerializer(1024); + serializer.serialize(sortKey, output); + byte[] serializedBytes = output.getCopyOfBuffer(); + assertThat(serializedBytes.length) + .as( + "Serialized bytes for sort key should be 38 bytes (34 UUID text + 4 byte integer of string length") + .isEqualTo(38); + + DataInputDeserializer input = new DataInputDeserializer(serializedBytes); + SortKey deserialized = serializer.deserialize(input); + assertThat(deserialized).isEqualTo(sortKey); + } } diff --git a/flink/v1.18/build.gradle b/flink/v1.18/build.gradle index f06318af83a3..aac01c9c6931 100644 --- a/flink/v1.18/build.gradle +++ b/flink/v1.18/build.gradle @@ -66,6 +66,8 @@ project(":iceberg-flink:iceberg-flink-${flinkMajorVersion}") { exclude group: 'org.slf4j' } + implementation libs.datasketches + testImplementation libs.flink118.connector.test.utils testImplementation libs.flink118.core testImplementation libs.flink118.runtime diff --git a/flink/v1.18/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java b/flink/v1.18/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java index 3b2c74fd6ece..a9ad386a5a4a 100644 --- a/flink/v1.18/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java +++ b/flink/v1.18/flink/src/jmh/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitionerBenchmark.java @@ -19,6 +19,7 @@ package org.apache.iceberg.flink.sink.shuffle; import java.nio.charset.StandardCharsets; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.NavigableMap; @@ -28,6 +29,8 @@ import 
org.apache.iceberg.Schema; import org.apache.iceberg.SortKey; import org.apache.iceberg.SortOrder; +import org.apache.iceberg.SortOrderComparators; +import org.apache.iceberg.StructLike; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; @@ -67,6 +70,8 @@ public class MapRangePartitionerBenchmark { Types.NestedField.required(9, "name9", Types.StringType.get())); private static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA).asc("id").build(); + private static final Comparator SORT_ORDER_COMPARTOR = + SortOrderComparators.forSchema(SCHEMA, SORT_ORDER); private static final SortKey SORT_KEY = new SortKey(SCHEMA, SORT_ORDER); private MapRangePartitioner partitioner; @@ -83,10 +88,11 @@ public void setupBenchmark() { mapStatistics.put(sortKey, weight); }); - MapDataStatistics dataStatistics = new MapDataStatistics(mapStatistics); + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(2, mapStatistics, 0.0, SORT_ORDER_COMPARTOR); this.partitioner = new MapRangePartitioner( - SCHEMA, SortOrder.builderFor(SCHEMA).asc("id").build(), dataStatistics, 2); + SCHEMA, SortOrder.builderFor(SCHEMA).asc("id").build(), mapAssignment); List keys = Lists.newArrayList(weights.keySet().iterator()); long[] weightsCDF = new long[keys.size()]; diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java deleted file mode 100644 index 157f04b8b0ed..000000000000 --- a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatistics.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.iceberg.flink.sink.shuffle; - -import java.io.Serializable; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; -import org.apache.iceberg.relocated.com.google.common.base.Preconditions; - -/** - * AggregatedStatistics is used by {@link DataStatisticsCoordinator} to collect {@link - * DataStatistics} from {@link DataStatisticsOperator} subtasks for specific checkpoint. It stores - * the merged {@link DataStatistics} result from all reported subtasks. 
- */ -class AggregatedStatistics, S> implements Serializable { - - private final long checkpointId; - private final DataStatistics dataStatistics; - - AggregatedStatistics(long checkpoint, TypeSerializer> statisticsSerializer) { - this.checkpointId = checkpoint; - this.dataStatistics = statisticsSerializer.createInstance(); - } - - AggregatedStatistics(long checkpoint, DataStatistics dataStatistics) { - this.checkpointId = checkpoint; - this.dataStatistics = dataStatistics; - } - - long checkpointId() { - return checkpointId; - } - - DataStatistics dataStatistics() { - return dataStatistics; - } - - void mergeDataStatistic(String operatorName, long eventCheckpointId, D eventDataStatistics) { - Preconditions.checkArgument( - checkpointId == eventCheckpointId, - "Received unexpected event from operator %s checkpoint %s. Expected checkpoint %s", - operatorName, - eventCheckpointId, - checkpointId); - dataStatistics.merge(eventDataStatistics); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("checkpointId", checkpointId) - .add("dataStatistics", dataStatistics) - .toString(); - } -} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java index e8ff61dbeb27..338523b7b074 100644 --- a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/AggregatedStatisticsTracker.java @@ -18,116 +18,238 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import java.util.Map; +import java.util.NavigableMap; import java.util.Set; +import javax.annotation.Nullable; +import org.apache.datasketches.sampling.ReservoirItemsSketch; +import org.apache.datasketches.sampling.ReservoirItemsUnion; import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.SortOrder; import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * AggregatedStatisticsTracker is used by {@link DataStatisticsCoordinator} to track the in progress - * {@link AggregatedStatistics} received from {@link DataStatisticsOperator} subtasks for specific - * checkpoint. + * AggregatedStatisticsTracker tracks the statistics aggregation received from {@link + * DataStatisticsOperator} subtasks for every checkpoint. 
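A minimal sketch of the completion rule the new tracker applies, namely one in-progress aggregation per checkpoint, duplicate subtask reports ignored, and the merged result emitted once all parallel subtasks have contributed. It uses a made-up name (CheckpointAggregator) and shows only the map-statistics path; it is not the actual AggregatedStatisticsTracker code.

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

class CheckpointAggregator {
  private final int parallelism;
  private final Map<Long, Set<Integer>> reportedSubtasksPerCheckpoint = new HashMap<>();
  private final Map<Long, Map<String, Long>> keyFrequencyPerCheckpoint = new HashMap<>();

  CheckpointAggregator(int parallelism) {
    this.parallelism = parallelism;
  }

  // Returns the merged key frequencies when the last subtask for the checkpoint reports, else null.
  Map<String, Long> update(long checkpointId, int subtask, Map<String, Long> taskStats) {
    Set<Integer> reported =
        reportedSubtasksPerCheckpoint.computeIfAbsent(checkpointId, id -> new HashSet<>());
    if (!reported.add(subtask)) {
      return null; // a duplicate report from the same subtask is ignored
    }

    Map<String, Long> merged =
        keyFrequencyPerCheckpoint.computeIfAbsent(checkpointId, id -> new HashMap<>());
    taskStats.forEach((key, count) -> merged.merge(key, count, Long::sum));

    if (reported.size() == parallelism) {
      reportedSubtasksPerCheckpoint.remove(checkpointId);
      return keyFrequencyPerCheckpoint.remove(checkpointId); // aggregation for this checkpoint is complete
    }

    return null;
  }
}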
*/ -class AggregatedStatisticsTracker, S> { +class AggregatedStatisticsTracker { private static final Logger LOG = LoggerFactory.getLogger(AggregatedStatisticsTracker.class); - private static final double ACCEPT_PARTIAL_AGGR_THRESHOLD = 90; + private final String operatorName; - private final TypeSerializer> statisticsSerializer; private final int parallelism; - private final Set inProgressSubtaskSet; - private volatile AggregatedStatistics inProgressStatistics; + private final TypeSerializer statisticsSerializer; + private final int downstreamParallelism; + private final StatisticsType statisticsType; + private final int switchToSketchThreshold; + private final NavigableMap aggregationsPerCheckpoint; + + private CompletedStatistics completedStatistics; AggregatedStatisticsTracker( String operatorName, - TypeSerializer> statisticsSerializer, - int parallelism) { + int parallelism, + Schema schema, + SortOrder sortOrder, + int downstreamParallelism, + StatisticsType statisticsType, + int switchToSketchThreshold, + @Nullable CompletedStatistics restoredStatistics) { this.operatorName = operatorName; - this.statisticsSerializer = statisticsSerializer; this.parallelism = parallelism; - this.inProgressSubtaskSet = Sets.newHashSet(); + this.statisticsSerializer = + new DataStatisticsSerializer(new SortKeySerializer(schema, sortOrder)); + this.downstreamParallelism = downstreamParallelism; + this.statisticsType = statisticsType; + this.switchToSketchThreshold = switchToSketchThreshold; + this.completedStatistics = restoredStatistics; + + this.aggregationsPerCheckpoint = Maps.newTreeMap(); } - AggregatedStatistics updateAndCheckCompletion( - int subtask, DataStatisticsEvent event) { + CompletedStatistics updateAndCheckCompletion(int subtask, StatisticsEvent event) { long checkpointId = event.checkpointId(); + LOG.debug( + "Handling statistics event from subtask {} of operator {} for checkpoint {}", + subtask, + operatorName, + checkpointId); - if (inProgressStatistics != null && inProgressStatistics.checkpointId() > checkpointId) { + if (completedStatistics != null && completedStatistics.checkpointId() > checkpointId) { LOG.info( - "Expect data statistics for operator {} checkpoint {}, but receive event from older checkpoint {}. Ignore it.", + "Ignore stale statistics event from operator {} subtask {} for older checkpoint {}. " + + "Was expecting data statistics from checkpoint higher than {}", operatorName, - inProgressStatistics.checkpointId(), - checkpointId); + subtask, + checkpointId, + completedStatistics.checkpointId()); return null; } - AggregatedStatistics completedStatistics = null; - if (inProgressStatistics != null && inProgressStatistics.checkpointId() < checkpointId) { - if ((double) inProgressSubtaskSet.size() / parallelism * 100 - >= ACCEPT_PARTIAL_AGGR_THRESHOLD) { - completedStatistics = inProgressStatistics; - LOG.info( - "Received data statistics from {} subtasks out of total {} for operator {} at checkpoint {}. 
" - + "Complete data statistics aggregation at checkpoint {} as it is more than the threshold of {} percentage", - inProgressSubtaskSet.size(), - parallelism, - operatorName, + Aggregation aggregation = + aggregationsPerCheckpoint.computeIfAbsent( checkpointId, - inProgressStatistics.checkpointId(), - ACCEPT_PARTIAL_AGGR_THRESHOLD); + ignored -> + new Aggregation( + parallelism, + downstreamParallelism, + switchToSketchThreshold, + statisticsType, + StatisticsUtil.collectType(statisticsType, completedStatistics))); + DataStatistics dataStatistics = + StatisticsUtil.deserializeDataStatistics(event.statisticsBytes(), statisticsSerializer); + if (!aggregation.merge(subtask, dataStatistics)) { + LOG.debug( + "Ignore duplicate data statistics from operator {} subtask {} for checkpoint {}.", + operatorName, + subtask, + checkpointId); + } + + if (aggregation.isComplete()) { + this.completedStatistics = aggregation.completedStatistics(checkpointId); + // clean up aggregations up to the completed checkpoint id + aggregationsPerCheckpoint.headMap(checkpointId, true).clear(); + return completedStatistics; + } + + return null; + } + + @VisibleForTesting + NavigableMap aggregationsPerCheckpoint() { + return aggregationsPerCheckpoint; + } + + static class Aggregation { + private static final Logger LOG = LoggerFactory.getLogger(Aggregation.class); + + private final Set subtaskSet; + private final int parallelism; + private final int downstreamParallelism; + private final int switchToSketchThreshold; + private final StatisticsType configuredType; + private StatisticsType currentType; + private Map mapStatistics; + private ReservoirItemsUnion sketchStatistics; + + Aggregation( + int parallelism, + int downstreamParallelism, + int switchToSketchThreshold, + StatisticsType configuredType, + StatisticsType currentType) { + this.subtaskSet = Sets.newHashSet(); + this.parallelism = parallelism; + this.downstreamParallelism = downstreamParallelism; + this.switchToSketchThreshold = switchToSketchThreshold; + this.configuredType = configuredType; + this.currentType = currentType; + + if (currentType == StatisticsType.Map) { + this.mapStatistics = Maps.newHashMap(); + this.sketchStatistics = null; } else { - LOG.info( - "Received data statistics from {} subtasks out of total {} for operator {} at checkpoint {}. 
" - + "Aborting the incomplete aggregation for checkpoint {}", - inProgressSubtaskSet.size(), - parallelism, - operatorName, - checkpointId, - inProgressStatistics.checkpointId()); + this.mapStatistics = null; + this.sketchStatistics = + ReservoirItemsUnion.newInstance( + SketchUtil.determineCoordinatorReservoirSize(downstreamParallelism)); } + } - inProgressStatistics = null; - inProgressSubtaskSet.clear(); + @VisibleForTesting + Set subtaskSet() { + return subtaskSet; } - if (inProgressStatistics == null) { - LOG.info("Starting a new data statistics for checkpoint {}", checkpointId); - inProgressStatistics = new AggregatedStatistics<>(checkpointId, statisticsSerializer); - inProgressSubtaskSet.clear(); + @VisibleForTesting + StatisticsType currentType() { + return currentType; } - if (!inProgressSubtaskSet.add(subtask)) { - LOG.debug( - "Ignore duplicated data statistics from operator {} subtask {} for checkpoint {}.", - operatorName, - subtask, - checkpointId); - } else { - inProgressStatistics.mergeDataStatistic( - operatorName, - event.checkpointId(), - DataStatisticsUtil.deserializeDataStatistics( - event.statisticsBytes(), statisticsSerializer)); + @VisibleForTesting + Map mapStatistics() { + return mapStatistics; } - if (inProgressSubtaskSet.size() == parallelism) { - completedStatistics = inProgressStatistics; - LOG.info( - "Received data statistics from all {} operators {} for checkpoint {}. Return last completed aggregator {}.", - parallelism, - operatorName, - inProgressStatistics.checkpointId(), - completedStatistics.dataStatistics()); - inProgressStatistics = new AggregatedStatistics<>(checkpointId + 1, statisticsSerializer); - inProgressSubtaskSet.clear(); + @VisibleForTesting + ReservoirItemsUnion sketchStatistics() { + return sketchStatistics; } - return completedStatistics; - } + private boolean isComplete() { + return subtaskSet.size() == parallelism; + } - @VisibleForTesting - AggregatedStatistics inProgressStatistics() { - return inProgressStatistics; + /** @return false if duplicate */ + private boolean merge(int subtask, DataStatistics taskStatistics) { + if (subtaskSet.contains(subtask)) { + return false; + } + + subtaskSet.add(subtask); + merge(taskStatistics); + return true; + } + + @SuppressWarnings("unchecked") + private void merge(DataStatistics taskStatistics) { + if (taskStatistics.type() == StatisticsType.Map) { + Map taskMapStats = (Map) taskStatistics.result(); + if (currentType == StatisticsType.Map) { + taskMapStats.forEach((key, count) -> mapStatistics.merge(key, count, Long::sum)); + if (configuredType == StatisticsType.Auto + && mapStatistics.size() > switchToSketchThreshold) { + convertCoordinatorToSketch(); + } + } else { + // convert task stats to sketch first + ReservoirItemsSketch taskSketch = + ReservoirItemsSketch.newInstance( + SketchUtil.determineOperatorReservoirSize(parallelism, downstreamParallelism)); + SketchUtil.convertMapToSketch(taskMapStats, taskSketch::update); + sketchStatistics.update(taskSketch); + } + } else { + ReservoirItemsSketch taskSketch = + (ReservoirItemsSketch) taskStatistics.result(); + if (currentType == StatisticsType.Map) { + // convert global stats to sketch first + convertCoordinatorToSketch(); + } + + sketchStatistics.update(taskSketch); + } + } + + private void convertCoordinatorToSketch() { + this.sketchStatistics = + ReservoirItemsUnion.newInstance( + SketchUtil.determineCoordinatorReservoirSize(downstreamParallelism)); + SketchUtil.convertMapToSketch(mapStatistics, sketchStatistics::update); + 
this.currentType = StatisticsType.Sketch; + this.mapStatistics = null; + } + + private CompletedStatistics completedStatistics(long checkpointId) { + if (currentType == StatisticsType.Map) { + LOG.info("Completed map statistics aggregation with {} keys", mapStatistics.size()); + return CompletedStatistics.fromKeyFrequency(checkpointId, mapStatistics); + } else { + ReservoirItemsSketch sketch = sketchStatistics.getResult(); + LOG.info( + "Completed sketch statistics aggregation: " + + "reservoir size = {}, number of items seen = {}, number of samples = {}", + sketch.getK(), + sketch.getN(), + sketch.getNumSamples()); + return CompletedStatistics.fromKeySamples(checkpointId, sketch.getSamples()); + } + } } } diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/CompletedStatistics.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/CompletedStatistics.java new file mode 100644 index 000000000000..c0e228965ddd --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/CompletedStatistics.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Arrays; +import java.util.Map; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; + +/** + * This is what {@link AggregatedStatisticsTracker} returns upon a completed statistics aggregation + * from all subtasks. It contains the raw statistics (Map or reservoir samples). 
+ */ +class CompletedStatistics { + private final long checkpointId; + private final StatisticsType type; + private final Map keyFrequency; + private final SortKey[] keySamples; + + static CompletedStatistics fromKeyFrequency(long checkpointId, Map stats) { + return new CompletedStatistics(checkpointId, StatisticsType.Map, stats, null); + } + + static CompletedStatistics fromKeySamples(long checkpointId, SortKey[] keySamples) { + return new CompletedStatistics(checkpointId, StatisticsType.Sketch, null, keySamples); + } + + CompletedStatistics( + long checkpointId, + StatisticsType type, + Map keyFrequency, + SortKey[] keySamples) { + this.checkpointId = checkpointId; + this.type = type; + this.keyFrequency = keyFrequency; + this.keySamples = keySamples; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("checkpointId", checkpointId) + .add("type", type) + .add("keyFrequency", keyFrequency) + .add("keySamples", keySamples) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof CompletedStatistics)) { + return false; + } + + CompletedStatistics other = (CompletedStatistics) o; + return Objects.equal(checkpointId, other.checkpointId) + && Objects.equal(type, other.type) + && Objects.equal(keyFrequency, other.keyFrequency()) + && Arrays.equals(keySamples, other.keySamples()); + } + + @Override + public int hashCode() { + return Objects.hashCode(checkpointId, type, keyFrequency, keySamples); + } + + long checkpointId() { + return checkpointId; + } + + StatisticsType type() { + return type; + } + + Map keyFrequency() { + return keyFrequency; + } + + SortKey[] keySamples() { + return keySamples; + } +} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/CompletedStatisticsSerializer.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/CompletedStatisticsSerializer.java new file mode 100644 index 000000000000..7f55188e7f8c --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/CompletedStatisticsSerializer.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.EnumSerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.MapSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.iceberg.SortKey; + +class CompletedStatisticsSerializer extends TypeSerializer { + private final TypeSerializer sortKeySerializer; + private final EnumSerializer statisticsTypeSerializer; + private final MapSerializer keyFrequencySerializer; + private final ListSerializer keySamplesSerializer; + + CompletedStatisticsSerializer(TypeSerializer sortKeySerializer) { + this.sortKeySerializer = sortKeySerializer; + this.statisticsTypeSerializer = new EnumSerializer<>(StatisticsType.class); + this.keyFrequencySerializer = new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE); + this.keySamplesSerializer = new ListSerializer<>(sortKeySerializer); + } + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer duplicate() { + return new CompletedStatisticsSerializer(sortKeySerializer); + } + + @Override + public CompletedStatistics createInstance() { + return CompletedStatistics.fromKeyFrequency(0L, Collections.emptyMap()); + } + + @Override + public CompletedStatistics copy(CompletedStatistics from) { + return new CompletedStatistics( + from.checkpointId(), from.type(), from.keyFrequency(), from.keySamples()); + } + + @Override + public CompletedStatistics copy(CompletedStatistics from, CompletedStatistics reuse) { + // no benefit of reuse + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(CompletedStatistics record, DataOutputView target) throws IOException { + target.writeLong(record.checkpointId()); + statisticsTypeSerializer.serialize(record.type(), target); + if (record.type() == StatisticsType.Map) { + keyFrequencySerializer.serialize(record.keyFrequency(), target); + } else { + keySamplesSerializer.serialize(Arrays.asList(record.keySamples()), target); + } + } + + @Override + public CompletedStatistics deserialize(DataInputView source) throws IOException { + long checkpointId = source.readLong(); + StatisticsType type = statisticsTypeSerializer.deserialize(source); + if (type == StatisticsType.Map) { + Map keyFrequency = keyFrequencySerializer.deserialize(source); + return CompletedStatistics.fromKeyFrequency(checkpointId, keyFrequency); + } else { + List sortKeys = keySamplesSerializer.deserialize(source); + SortKey[] keySamples = new SortKey[sortKeys.size()]; + keySamples = sortKeys.toArray(keySamples); + return CompletedStatistics.fromKeySamples(checkpointId, keySamples); + } + } + + @Override + public CompletedStatistics deserialize(CompletedStatistics reuse, DataInputView source) + throws IOException { + // not much benefit to reuse + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws 
IOException { + serialize(deserialize(source), target); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + CompletedStatisticsSerializer other = (CompletedStatisticsSerializer) obj; + return Objects.equals(sortKeySerializer, other.sortKeySerializer); + } + + @Override + public int hashCode() { + return sortKeySerializer.hashCode(); + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new CompletedStatisticsSerializerSnapshot(this); + } + + public static class CompletedStatisticsSerializerSnapshot + extends CompositeTypeSerializerSnapshot { + private static final int CURRENT_VERSION = 1; + + /** Constructor for read instantiation. */ + @SuppressWarnings({"unused", "checkstyle:RedundantModifier"}) + public CompletedStatisticsSerializerSnapshot() { + super(CompletedStatisticsSerializer.class); + } + + @SuppressWarnings("checkstyle:RedundantModifier") + public CompletedStatisticsSerializerSnapshot(CompletedStatisticsSerializer serializer) { + super(serializer); + } + + @Override + protected int getCurrentOuterSnapshotVersion() { + return CURRENT_VERSION; + } + + @Override + protected TypeSerializer[] getNestedSerializers( + CompletedStatisticsSerializer outerSerializer) { + return new TypeSerializer[] {outerSerializer.sortKeySerializer}; + } + + @Override + protected CompletedStatisticsSerializer createOuterSerializerWithNestedSerializers( + TypeSerializer[] nestedSerializers) { + SortKeySerializer sortKeySerializer = (SortKeySerializer) nestedSerializers[0]; + return new CompletedStatisticsSerializer(sortKeySerializer); + } + } +} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java index 9d7cf179ab1c..76c59cd5f4b8 100644 --- a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java @@ -18,6 +18,8 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import java.util.Map; +import org.apache.datasketches.sampling.ReservoirItemsSketch; import org.apache.flink.annotation.Internal; import org.apache.iceberg.SortKey; @@ -29,29 +31,18 @@ * (sketching) can be used. */ @Internal -interface DataStatistics, S> { +interface DataStatistics { + + StatisticsType type(); - /** - * Check if data statistics contains any statistics information. - * - * @return true if data statistics doesn't contain any statistics information - */ boolean isEmpty(); /** Add row sortKey to data statistics. */ void add(SortKey sortKey); /** - * Merge current statistics with other statistics. - * - * @param otherStatistics the statistics to be merged - */ - void merge(D otherStatistics); - - /** - * Get the underline statistics. - * - * @return the underline statistics + * Get the collected statistics. 
Could be a {@link Map} (low cardinality) or {@link + * ReservoirItemsSketch} (high cardinality) */ - S statistics(); + Object result(); } diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java index c8ac79c61bf6..3b21fbae315a 100644 --- a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java @@ -18,6 +18,7 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import java.util.Comparator; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; @@ -35,6 +36,10 @@ import org.apache.flink.util.Preconditions; import org.apache.flink.util.ThrowableCatchingRunnable; import org.apache.flink.util.function.ThrowingRunnable; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.SortOrderComparators; +import org.apache.iceberg.StructLike; import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; import org.apache.iceberg.relocated.com.google.common.collect.Iterables; import org.apache.iceberg.relocated.com.google.common.collect.Maps; @@ -44,51 +49,86 @@ import org.slf4j.LoggerFactory; /** - * DataStatisticsCoordinator receives {@link DataStatisticsEvent} from {@link - * DataStatisticsOperator} every subtask and then merge them together. Once aggregation for all - * subtasks data statistics completes, DataStatisticsCoordinator will send the aggregated data - * statistics back to {@link DataStatisticsOperator}. In the end a custom partitioner will - * distribute traffic based on the aggregated data statistics to improve data clustering. + * DataStatisticsCoordinator receives {@link StatisticsEvent} from {@link DataStatisticsOperator} + * every subtask and then merge them together. Once aggregation for all subtasks data statistics + * completes, DataStatisticsCoordinator will send the aggregated data statistics back to {@link + * DataStatisticsOperator}. In the end a custom partitioner will distribute traffic based on the + * aggregated data statistics to improve data clustering. 
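The slimmed-down DataStatistics contract above is intentionally untyped: result() returns either a key-frequency Map or a ReservoirItemsSketch. A standalone, generic approximation of the Map-backed shape (this is not the project's MapDataStatistics, and K stands in for SortKey):

import java.util.HashMap;
import java.util.Map;

class MapFrequencyStatistics<K> {
  private final Map<K, Long> keyFrequency = new HashMap<>();

  StatisticsType type() {
    return StatisticsType.Map;
  }

  boolean isEmpty() {
    return keyFrequency.isEmpty();
  }

  void add(K key) {
    // count occurrences per key; high cardinality here is what triggers the Sketch switch
    keyFrequency.merge(key, 1L, Long::sum);
  }

  Object result() {
    return keyFrequency;
  }
}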
*/ @Internal -class DataStatisticsCoordinator, S> implements OperatorCoordinator { +class DataStatisticsCoordinator implements OperatorCoordinator { private static final Logger LOG = LoggerFactory.getLogger(DataStatisticsCoordinator.class); private final String operatorName; + private final OperatorCoordinator.Context context; + private final Schema schema; + private final SortOrder sortOrder; + private final Comparator comparator; + private final int downstreamParallelism; + private final StatisticsType statisticsType; + private final double closeFileCostWeightPercentage; + private final ExecutorService coordinatorExecutor; - private final OperatorCoordinator.Context operatorCoordinatorContext; private final SubtaskGateways subtaskGateways; private final CoordinatorExecutorThreadFactory coordinatorThreadFactory; - private final TypeSerializer> statisticsSerializer; - private final transient AggregatedStatisticsTracker aggregatedStatisticsTracker; - private volatile AggregatedStatistics completedStatistics; - private volatile boolean started; + private final TypeSerializer completedStatisticsSerializer; + private final TypeSerializer globalStatisticsSerializer; + + private transient boolean started; + private transient AggregatedStatisticsTracker aggregatedStatisticsTracker; + private transient CompletedStatistics completedStatistics; + private transient GlobalStatistics globalStatistics; DataStatisticsCoordinator( String operatorName, OperatorCoordinator.Context context, - TypeSerializer> statisticsSerializer) { + Schema schema, + SortOrder sortOrder, + int downstreamParallelism, + StatisticsType statisticsType, + double closeFileCostWeightPercentage) { this.operatorName = operatorName; + this.context = context; + this.schema = schema; + this.sortOrder = sortOrder; + this.comparator = SortOrderComparators.forSchema(schema, sortOrder); + this.downstreamParallelism = downstreamParallelism; + this.statisticsType = statisticsType; + this.closeFileCostWeightPercentage = closeFileCostWeightPercentage; + this.coordinatorThreadFactory = new CoordinatorExecutorThreadFactory( "DataStatisticsCoordinator-" + operatorName, context.getUserCodeClassloader()); this.coordinatorExecutor = Executors.newSingleThreadExecutor(coordinatorThreadFactory); - this.operatorCoordinatorContext = context; - this.subtaskGateways = new SubtaskGateways(operatorName, parallelism()); - this.statisticsSerializer = statisticsSerializer; - this.aggregatedStatisticsTracker = - new AggregatedStatisticsTracker<>(operatorName, statisticsSerializer, parallelism()); + this.subtaskGateways = new SubtaskGateways(operatorName, context.currentParallelism()); + SortKeySerializer sortKeySerializer = new SortKeySerializer(schema, sortOrder); + this.completedStatisticsSerializer = new CompletedStatisticsSerializer(sortKeySerializer); + this.globalStatisticsSerializer = new GlobalStatisticsSerializer(sortKeySerializer); } @Override public void start() throws Exception { LOG.info("Starting data statistics coordinator: {}.", operatorName); - started = true; + this.started = true; + + // statistics are restored already in resetToCheckpoint() before start() called + this.aggregatedStatisticsTracker = + new AggregatedStatisticsTracker( + operatorName, + context.currentParallelism(), + schema, + sortOrder, + downstreamParallelism, + statisticsType, + SketchUtil.COORDINATOR_SKETCH_SWITCH_THRESHOLD, + completedStatistics); } @Override public void close() throws Exception { coordinatorExecutor.shutdown(); + this.aggregatedStatisticsTracker = null; + 
this.started = false; LOG.info("Closed data statistics coordinator: {}.", operatorName); } @@ -148,7 +188,7 @@ private void runInCoordinatorThread(ThrowingRunnable action, String a operatorName, actionString, t); - operatorCoordinatorContext.failJob(t); + context.failJob(t); } }); } @@ -157,42 +197,102 @@ private void ensureStarted() { Preconditions.checkState(started, "The coordinator of %s has not started yet.", operatorName); } - private int parallelism() { - return operatorCoordinatorContext.currentParallelism(); - } - - private void handleDataStatisticRequest(int subtask, DataStatisticsEvent event) { - AggregatedStatistics aggregatedStatistics = + private void handleDataStatisticRequest(int subtask, StatisticsEvent event) { + CompletedStatistics maybeCompletedStatistics = aggregatedStatisticsTracker.updateAndCheckCompletion(subtask, event); - if (aggregatedStatistics != null) { - completedStatistics = aggregatedStatistics; - sendDataStatisticsToSubtasks( - completedStatistics.checkpointId(), completedStatistics.dataStatistics()); + if (maybeCompletedStatistics != null) { + // completedStatistics contains the complete samples, which is needed to compute + // the range bounds in globalStatistics if downstreamParallelism changed. + this.completedStatistics = maybeCompletedStatistics; + // globalStatistics only contains assignment calculated based on Map or Sketch statistics + this.globalStatistics = + globalStatistics( + maybeCompletedStatistics, + downstreamParallelism, + comparator, + closeFileCostWeightPercentage); + sendGlobalStatisticsToSubtasks(globalStatistics); + } + } + + private static GlobalStatistics globalStatistics( + CompletedStatistics completedStatistics, + int downstreamParallelism, + Comparator comparator, + double closeFileCostWeightPercentage) { + if (completedStatistics.type() == StatisticsType.Sketch) { + // range bound is a much smaller array compared to the complete samples. + // It helps reduce the amount of data transfer from coordinator to operator subtasks. + return GlobalStatistics.fromRangeBounds( + completedStatistics.checkpointId(), + SketchUtil.rangeBounds( + downstreamParallelism, comparator, completedStatistics.keySamples())); + } else { + return GlobalStatistics.fromMapAssignment( + completedStatistics.checkpointId(), + MapAssignment.fromKeyFrequency( + downstreamParallelism, + completedStatistics.keyFrequency(), + closeFileCostWeightPercentage, + comparator)); } } @SuppressWarnings("FutureReturnValueIgnored") - private void sendDataStatisticsToSubtasks( - long checkpointId, DataStatistics globalDataStatistics) { - callInCoordinatorThread( + private void sendGlobalStatisticsToSubtasks(GlobalStatistics statistics) { + runInCoordinatorThread( () -> { - DataStatisticsEvent dataStatisticsEvent = - DataStatisticsEvent.create(checkpointId, globalDataStatistics, statisticsSerializer); - int parallelism = parallelism(); - for (int i = 0; i < parallelism; ++i) { - subtaskGateways.getSubtaskGateway(i).sendEvent(dataStatisticsEvent); + LOG.info( + "Broadcast latest global statistics from checkpoint {} to all subtasks", + statistics.checkpointId()); + // applyImmediately is set to false so that operator subtasks can + // apply the change at checkpoint boundary + StatisticsEvent statisticsEvent = + StatisticsEvent.createGlobalStatisticsEvent( + statistics, globalStatisticsSerializer, false); + for (int i = 0; i < context.currentParallelism(); ++i) { + // Ignore future return value for potential error (e.g. subtask down). 
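When the completed statistics are sketch-based, the coordinator condenses the reservoir samples into downstreamParallelism - 1 range bounds before broadcasting, which keeps the event small. SketchUtil.rangeBounds is not shown in this diff; the sketch below is a hedged, simplified version of the usual quantile-cut approach (Integer keys stand in for SortKey, and the real utility may handle duplicate keys differently):

import java.util.Arrays;
import java.util.Comparator;

final class RangeBoundsExample {
  private RangeBoundsExample() {}

  static Integer[] rangeBounds(
      int downstreamParallelism, Comparator<Integer> comparator, Integer[] samples) {
    Integer[] sorted = samples.clone();
    Arrays.sort(sorted, comparator);

    int numBounds = downstreamParallelism - 1;
    Integer[] bounds = new Integer[numBounds];
    for (int i = 0; i < numBounds; ++i) {
      // take the sample at the (i + 1) / downstreamParallelism quantile as the
      // upper bound of partition i
      int idx = (int) (((long) (i + 1) * sorted.length) / downstreamParallelism);
      bounds[i] = sorted[Math.min(idx, sorted.length - 1)];
    }

    return bounds;
  }
}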
+ // Upon restart, subtasks send request to coordinator to refresh statistics + // if there is any difference + subtaskGateways.getSubtaskGateway(i).sendEvent(statisticsEvent); } - - return null; }, String.format( "Failed to send operator %s coordinator global data statistics for checkpoint %d", - operatorName, checkpointId)); + operatorName, statistics.checkpointId())); + } + + @SuppressWarnings("FutureReturnValueIgnored") + private void handleRequestGlobalStatisticsEvent(int subtask, RequestGlobalStatisticsEvent event) { + if (globalStatistics != null) { + runInCoordinatorThread( + () -> { + if (event.signature() != null && event.signature() != globalStatistics.hashCode()) { + LOG.debug( + "Skip responding to statistics request from subtask {}, as hashCode matches or not included in the request", + subtask); + } else { + LOG.info( + "Send latest global statistics from checkpoint {} to subtask {}", + globalStatistics.checkpointId(), + subtask); + StatisticsEvent statisticsEvent = + StatisticsEvent.createGlobalStatisticsEvent( + globalStatistics, globalStatisticsSerializer, true); + subtaskGateways.getSubtaskGateway(subtask).sendEvent(statisticsEvent); + } + }, + String.format( + "Failed to send operator %s coordinator global data statistics to requesting subtask %d for checkpoint %d", + operatorName, subtask, globalStatistics.checkpointId())); + } else { + LOG.info( + "Ignore global statistics request from subtask {} as statistics not available", subtask); + } } @Override - @SuppressWarnings("unchecked") public void handleEventFromOperator(int subtask, int attemptNumber, OperatorEvent event) { runInCoordinatorThread( () -> { @@ -202,8 +302,14 @@ public void handleEventFromOperator(int subtask, int attemptNumber, OperatorEven attemptNumber, operatorName, event); - Preconditions.checkArgument(event instanceof DataStatisticsEvent); - handleDataStatisticRequest(subtask, ((DataStatisticsEvent) event)); + if (event instanceof StatisticsEvent) { + handleDataStatisticRequest(subtask, ((StatisticsEvent) event)); + } else if (event instanceof RequestGlobalStatisticsEvent) { + handleRequestGlobalStatisticsEvent(subtask, (RequestGlobalStatisticsEvent) event); + } else { + throw new IllegalArgumentException( + "Invalid operator event type: " + event.getClass().getCanonicalName()); + } }, String.format( "handling operator event %s from subtask %d (#%d)", @@ -219,8 +325,8 @@ public void checkpointCoordinator(long checkpointId, CompletableFuture r operatorName, checkpointId); resultFuture.complete( - DataStatisticsUtil.serializeAggregatedStatistics( - completedStatistics, statisticsSerializer)); + StatisticsUtil.serializeCompletedStatistics( + completedStatistics, completedStatisticsSerializer)); }, String.format("taking checkpoint %d", checkpointId)); } @@ -229,11 +335,9 @@ public void checkpointCoordinator(long checkpointId, CompletableFuture r public void notifyCheckpointComplete(long checkpointId) {} @Override - public void resetToCheckpoint(long checkpointId, @Nullable byte[] checkpointData) - throws Exception { + public void resetToCheckpoint(long checkpointId, byte[] checkpointData) { Preconditions.checkState( !started, "The coordinator %s can only be reset if it was not yet started", operatorName); - if (checkpointData == null) { LOG.info( "Data statistic coordinator {} has nothing to restore from checkpoint {}", @@ -244,8 +348,13 @@ public void resetToCheckpoint(long checkpointId, @Nullable byte[] checkpointData LOG.info( "Restoring data statistic coordinator {} from checkpoint {}", 
operatorName, checkpointId); - completedStatistics = - DataStatisticsUtil.deserializeAggregatedStatistics(checkpointData, statisticsSerializer); + this.completedStatistics = + StatisticsUtil.deserializeCompletedStatistics( + checkpointData, completedStatisticsSerializer); + // recompute global statistics in case downstream parallelism changed + this.globalStatistics = + globalStatistics( + completedStatistics, downstreamParallelism, comparator, closeFileCostWeightPercentage); } @Override @@ -269,7 +378,7 @@ public void executionAttemptFailed(int subtask, int attemptNumber, @Nullable Thr runInCoordinatorThread( () -> { LOG.info( - "Unregistering gateway after failure for subtask {} (#{}) of data statistic {}", + "Unregistering gateway after failure for subtask {} (#{}) of data statistics {}", subtask, attemptNumber, operatorName); @@ -295,14 +404,20 @@ public void executionAttemptReady(int subtask, int attemptNumber, SubtaskGateway } @VisibleForTesting - AggregatedStatistics completedStatistics() { + CompletedStatistics completedStatistics() { return completedStatistics; } + @VisibleForTesting + GlobalStatistics globalStatistics() { + return globalStatistics; + } + private static class SubtaskGateways { private final String operatorName; private final Map[] gateways; + @SuppressWarnings("unchecked") private SubtaskGateways(String operatorName, int parallelism) { this.operatorName = operatorName; gateways = new Map[parallelism]; diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java index 47dbfc3cfbe1..9d7d989c298e 100644 --- a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinatorProvider.java @@ -19,33 +19,52 @@ package org.apache.iceberg.flink.sink.shuffle; import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.operators.coordination.OperatorCoordinator; import org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortOrder; /** * DataStatisticsCoordinatorProvider provides the method to create new {@link * DataStatisticsCoordinator} */ @Internal -public class DataStatisticsCoordinatorProvider, S> - extends RecreateOnResetOperatorCoordinator.Provider { +public class DataStatisticsCoordinatorProvider extends RecreateOnResetOperatorCoordinator.Provider { private final String operatorName; - private final TypeSerializer> statisticsSerializer; + private final Schema schema; + private final SortOrder sortOrder; + private final int downstreamParallelism; + private final StatisticsType type; + private final double closeFileCostWeightPercentage; public DataStatisticsCoordinatorProvider( String operatorName, OperatorID operatorID, - TypeSerializer> statisticsSerializer) { + Schema schema, + SortOrder sortOrder, + int downstreamParallelism, + StatisticsType type, + double closeFileCostWeightPercentage) { super(operatorID); this.operatorName = operatorName; - this.statisticsSerializer = statisticsSerializer; + this.schema = schema; + this.sortOrder = sortOrder; + this.downstreamParallelism = downstreamParallelism; + this.type = type; + 
this.closeFileCostWeightPercentage = closeFileCostWeightPercentage; } @Override public OperatorCoordinator getCoordinator(OperatorCoordinator.Context context) { - return new DataStatisticsCoordinator<>(operatorName, context, statisticsSerializer); + return new DataStatisticsCoordinator( + operatorName, + context, + schema, + sortOrder, + downstreamParallelism, + type, + closeFileCostWeightPercentage); } } diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java index 5157a37cf2cd..59c38b239725 100644 --- a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java @@ -18,6 +18,7 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import java.util.Map; import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; @@ -47,9 +48,8 @@ * distribution to downstream subtasks. */ @Internal -class DataStatisticsOperator, S> - extends AbstractStreamOperator> - implements OneInputStreamOperator>, OperatorEventHandler { +public class DataStatisticsOperator extends AbstractStreamOperator + implements OneInputStreamOperator, OperatorEventHandler { private static final long serialVersionUID = 1L; @@ -57,141 +57,209 @@ class DataStatisticsOperator, S> private final RowDataWrapper rowDataWrapper; private final SortKey sortKey; private final OperatorEventGateway operatorEventGateway; - private final TypeSerializer> statisticsSerializer; - private transient volatile DataStatistics localStatistics; - private transient volatile DataStatistics globalStatistics; - private transient ListState> globalStatisticsState; + private final int downstreamParallelism; + private final StatisticsType statisticsType; + private final TypeSerializer taskStatisticsSerializer; + private final TypeSerializer globalStatisticsSerializer; + + private transient int parallelism; + private transient int subtaskIndex; + private transient ListState globalStatisticsState; + // current statistics type may be different from the config due to possible + // migration from Map statistics to Sketch statistics when high cardinality detected + private transient volatile StatisticsType taskStatisticsType; + private transient volatile DataStatistics localStatistics; + private transient volatile GlobalStatistics globalStatistics; DataStatisticsOperator( String operatorName, Schema schema, SortOrder sortOrder, OperatorEventGateway operatorEventGateway, - TypeSerializer> statisticsSerializer) { + int downstreamParallelism, + StatisticsType statisticsType) { this.operatorName = operatorName; this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct()); this.sortKey = new SortKey(schema, sortOrder); this.operatorEventGateway = operatorEventGateway; - this.statisticsSerializer = statisticsSerializer; + this.downstreamParallelism = downstreamParallelism; + this.statisticsType = statisticsType; + + SortKeySerializer sortKeySerializer = new SortKeySerializer(schema, sortOrder); + this.taskStatisticsSerializer = new DataStatisticsSerializer(sortKeySerializer); + this.globalStatisticsSerializer = new GlobalStatisticsSerializer(sortKeySerializer); } @Override public void initializeState(StateInitializationContext context) throws Exception { - 
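For context, a hedged example of wiring the new provider when building the sink; the operator name and numeric values are illustrative, and the actual wiring in the Flink sink builder is not part of this diff:

import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.iceberg.Schema;
import org.apache.iceberg.SortOrder;

final class ProviderWiringExample {
  private ProviderWiringExample() {}

  static DataStatisticsCoordinatorProvider newProvider(
      OperatorID operatorId, Schema schema, SortOrder sortOrder) {
    return new DataStatisticsCoordinatorProvider(
        "range-shuffle: db.table", // operator name shown in logs (illustrative)
        operatorId,
        schema,
        sortOrder,
        8, // downstream writer parallelism (illustrative)
        StatisticsType.Auto, // start with Map stats, switch to Sketch on high cardinality
        2.0); // close-file cost weight percentage (illustrative value)
  }
}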
localStatistics = statisticsSerializer.createInstance(); - globalStatisticsState = + this.parallelism = getRuntimeContext().getNumberOfParallelSubtasks(); + this.subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); + + // Use union state so that new subtasks can also restore global statistics during scale-up. + this.globalStatisticsState = context .getOperatorStateStore() .getUnionListState( - new ListStateDescriptor<>("globalStatisticsState", statisticsSerializer)); + new ListStateDescriptor<>("globalStatisticsState", globalStatisticsSerializer)); if (context.isRestored()) { - int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); if (globalStatisticsState.get() == null || !globalStatisticsState.get().iterator().hasNext()) { - LOG.warn( + LOG.info( "Operator {} subtask {} doesn't have global statistics state to restore", operatorName, subtaskIndex); - globalStatistics = statisticsSerializer.createInstance(); + // If Flink deprecates union state in the future, RequestGlobalStatisticsEvent can be + // leveraged to request global statistics from coordinator if new subtasks (scale-up case) + // has nothing to restore from. } else { + GlobalStatistics restoredStatistics = globalStatisticsState.get().iterator().next(); LOG.info( - "Restoring operator {} global statistics state for subtask {}", - operatorName, - subtaskIndex); - globalStatistics = globalStatisticsState.get().iterator().next(); + "Operator {} subtask {} restored global statistics state", operatorName, subtaskIndex); + this.globalStatistics = restoredStatistics; } - } else { - globalStatistics = statisticsSerializer.createInstance(); + + // Always request for new statistics from coordinator upon task initialization. + // There are a few scenarios this is needed + // 1. downstream writer parallelism changed due to rescale. + // 2. coordinator failed to send the aggregated statistics to subtask + // (e.g. due to subtask failure at the time). + // Records may flow before coordinator can respond. Range partitioner should be + // able to continue to operate with potentially suboptimal behavior (in sketch case). + LOG.info( + "Operator {} subtask {} requests new global statistics from coordinator ", + operatorName, + subtaskIndex); + // coordinator can use the hashCode (if available) in the request event to determine + // if operator already has the latest global statistics and respond can be skipped. + // This makes the handling cheap in most situations. + RequestGlobalStatisticsEvent event = + globalStatistics != null + ? 
new RequestGlobalStatisticsEvent(globalStatistics.hashCode()) + : new RequestGlobalStatisticsEvent(); + operatorEventGateway.sendEventToCoordinator(event); } + + this.taskStatisticsType = StatisticsUtil.collectType(statisticsType, globalStatistics); + this.localStatistics = + StatisticsUtil.createTaskStatistics(taskStatisticsType, parallelism, downstreamParallelism); } @Override public void open() throws Exception { - if (!globalStatistics.isEmpty()) { - output.collect( - new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + if (globalStatistics != null) { + output.collect(new StreamRecord<>(StatisticsOrRecord.fromStatistics(globalStatistics))); } } @Override - @SuppressWarnings("unchecked") public void handleOperatorEvent(OperatorEvent event) { - int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask(); Preconditions.checkArgument( - event instanceof DataStatisticsEvent, + event instanceof StatisticsEvent, String.format( "Operator %s subtask %s received unexpected operator event %s", operatorName, subtaskIndex, event.getClass())); - DataStatisticsEvent statisticsEvent = (DataStatisticsEvent) event; + StatisticsEvent statisticsEvent = (StatisticsEvent) event; LOG.info( - "Operator {} received global data event from coordinator checkpoint {}", + "Operator {} subtask {} received global data event from coordinator checkpoint {}", operatorName, + subtaskIndex, statisticsEvent.checkpointId()); - globalStatistics = - DataStatisticsUtil.deserializeDataStatistics( - statisticsEvent.statisticsBytes(), statisticsSerializer); - output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + this.globalStatistics = + StatisticsUtil.deserializeGlobalStatistics( + statisticsEvent.statisticsBytes(), globalStatisticsSerializer); + checkStatisticsTypeMigration(); + // if applyImmediately not set, wait until the checkpoint time to switch + if (statisticsEvent.applyImmediately()) { + output.collect(new StreamRecord<>(StatisticsOrRecord.fromStatistics(globalStatistics))); + } } @Override public void processElement(StreamRecord streamRecord) { + // collect data statistics RowData record = streamRecord.getValue(); StructLike struct = rowDataWrapper.wrap(record); sortKey.wrap(struct); localStatistics.add(sortKey); - output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromRecord(record))); + + checkStatisticsTypeMigration(); + output.collect(new StreamRecord<>(StatisticsOrRecord.fromRecord(record))); } @Override public void snapshotState(StateSnapshotContext context) throws Exception { long checkpointId = context.getCheckpointId(); - int subTaskId = getRuntimeContext().getIndexOfThisSubtask(); LOG.info( - "Snapshotting data statistics operator {} for checkpoint {} in subtask {}", + "Operator {} subtask {} snapshotting data statistics for checkpoint {}", operatorName, - checkpointId, - subTaskId); + subtaskIndex, + checkpointId); - // Pass global statistics to partitioners so that all the operators refresh statistics + // Pass global statistics to partitioner so that all the operators refresh statistics // at same checkpoint barrier - if (!globalStatistics.isEmpty()) { - output.collect( - new StreamRecord<>(DataStatisticsOrRecord.fromDataStatistics(globalStatistics))); + if (globalStatistics != null) { + output.collect(new StreamRecord<>(StatisticsOrRecord.fromStatistics(globalStatistics))); } // Only subtask 0 saves the state so that globalStatisticsState(UnionListState) stores // an exact copy of globalStatistics - if 
(!globalStatistics.isEmpty() && getRuntimeContext().getIndexOfThisSubtask() == 0) { + if (globalStatistics != null && getRuntimeContext().getIndexOfThisSubtask() == 0) { globalStatisticsState.clear(); LOG.info( - "Saving operator {} global statistics {} to state in subtask {}", - operatorName, - globalStatistics, - subTaskId); + "Operator {} subtask {} saving global statistics to state", operatorName, subtaskIndex); globalStatisticsState.add(globalStatistics); + LOG.debug( + "Operator {} subtask {} saved global statistics to state: {}", + operatorName, + subtaskIndex, + globalStatistics); } // For now, local statistics are sent to coordinator at checkpoint - operatorEventGateway.sendEventToCoordinator( - DataStatisticsEvent.create(checkpointId, localStatistics, statisticsSerializer)); - LOG.debug( - "Subtask {} of operator {} sent local statistics to coordinator at checkpoint{}: {}", - subTaskId, + LOG.info( + "Operator {} Subtask {} sending local statistics to coordinator for checkpoint {}", operatorName, - checkpointId, - localStatistics); + subtaskIndex, + checkpointId); + operatorEventGateway.sendEventToCoordinator( + StatisticsEvent.createTaskStatisticsEvent( + checkpointId, localStatistics, taskStatisticsSerializer)); // Recreate the local statistics - localStatistics = statisticsSerializer.createInstance(); + localStatistics = + StatisticsUtil.createTaskStatistics(taskStatisticsType, parallelism, downstreamParallelism); + } + + @SuppressWarnings("unchecked") + private void checkStatisticsTypeMigration() { + // only check if the statisticsType config is Auto and localStatistics is currently Map type + if (statisticsType == StatisticsType.Auto && localStatistics.type() == StatisticsType.Map) { + Map mapStatistics = (Map) localStatistics.result(); + // convert if local statistics has cardinality over the threshold or + // if received global statistics is already sketch type + if (mapStatistics.size() > SketchUtil.OPERATOR_SKETCH_SWITCH_THRESHOLD + || (globalStatistics != null && globalStatistics.type() == StatisticsType.Sketch)) { + LOG.info( + "Operator {} subtask {} switched local statistics from Map to Sketch.", + operatorName, + subtaskIndex); + this.taskStatisticsType = StatisticsType.Sketch; + this.localStatistics = + StatisticsUtil.createTaskStatistics( + taskStatisticsType, parallelism, downstreamParallelism); + SketchUtil.convertMapToSketch(mapStatistics, localStatistics::add); + } + } } @VisibleForTesting - DataStatistics localDataStatistics() { + DataStatistics localStatistics() { return localStatistics; } @VisibleForTesting - DataStatistics globalDataStatistics() { + GlobalStatistics globalStatistics() { return globalStatistics; } } diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsSerializer.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsSerializer.java new file mode 100644 index 000000000000..c25481b3c1f2 --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsSerializer.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
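Both the coordinator and the operator rely on SketchUtil.convertMapToSketch(map, consumer) when migrating from Map to Sketch statistics; that utility is not shown in this diff. One plausible shape for it is to replay every key into the reservoir consumer according to its observed frequency, with a cap to bound the work (the cap below is an assumption, not the project's constant):

import java.util.Map;
import java.util.function.Consumer;

final class MapToSketchReplayExample {
  private MapToSketchReplayExample() {}

  static <K> void convertMapToSketch(Map<K, Long> keyFrequency, Consumer<K> sketchConsumer) {
    long replayCap = 100_000L; // assumed cap so a single hot key cannot stall the conversion
    keyFrequency.forEach(
        (key, count) -> {
          long replays = Math.min(count, replayCap);
          for (long i = 0; i < replays; ++i) {
            sketchConsumer.accept(key);
          }
        });
  }
}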
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.sampling.ReservoirItemsSketch; +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.EnumSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.MapSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; + +@Internal +class DataStatisticsSerializer extends TypeSerializer { + private final TypeSerializer sortKeySerializer; + private final EnumSerializer statisticsTypeSerializer; + private final MapSerializer mapSerializer; + private final SortKeySketchSerializer sketchSerializer; + + DataStatisticsSerializer(TypeSerializer sortKeySerializer) { + this.sortKeySerializer = sortKeySerializer; + this.statisticsTypeSerializer = new EnumSerializer<>(StatisticsType.class); + this.mapSerializer = new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE); + this.sketchSerializer = new SortKeySketchSerializer(sortKeySerializer); + } + + @Override + public boolean isImmutableType() { + return false; + } + + @SuppressWarnings("ReferenceEquality") + @Override + public TypeSerializer duplicate() { + TypeSerializer duplicateSortKeySerializer = sortKeySerializer.duplicate(); + return (duplicateSortKeySerializer == sortKeySerializer) + ? this + : new DataStatisticsSerializer(duplicateSortKeySerializer); + } + + @Override + public DataStatistics createInstance() { + return new MapDataStatistics(); + } + + @SuppressWarnings("unchecked") + @Override + public DataStatistics copy(DataStatistics obj) { + StatisticsType statisticsType = obj.type(); + if (statisticsType == StatisticsType.Map) { + MapDataStatistics from = (MapDataStatistics) obj; + Map fromStats = (Map) from.result(); + Map toStats = Maps.newHashMap(fromStats); + return new MapDataStatistics(toStats); + } else if (statisticsType == StatisticsType.Sketch) { + // because ReservoirItemsSketch doesn't expose enough public methods for cloning, + // this implementation adopted the less efficient serialization and deserialization. 
+ SketchDataStatistics from = (SketchDataStatistics) obj; + ReservoirItemsSketch fromStats = (ReservoirItemsSketch) from.result(); + byte[] bytes = fromStats.toByteArray(sketchSerializer); + Memory memory = Memory.wrap(bytes); + ReservoirItemsSketch toStats = + ReservoirItemsSketch.heapify(memory, sketchSerializer); + return new SketchDataStatistics(toStats); + } else { + throw new IllegalArgumentException("Unsupported data statistics type: " + statisticsType); + } + } + + @Override + public DataStatistics copy(DataStatistics from, DataStatistics reuse) { + // not much benefit to reuse + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @SuppressWarnings("unchecked") + @Override + public void serialize(DataStatistics obj, DataOutputView target) throws IOException { + StatisticsType statisticsType = obj.type(); + statisticsTypeSerializer.serialize(obj.type(), target); + if (statisticsType == StatisticsType.Map) { + Map mapStatistics = (Map) obj.result(); + mapSerializer.serialize(mapStatistics, target); + } else if (statisticsType == StatisticsType.Sketch) { + ReservoirItemsSketch sketch = (ReservoirItemsSketch) obj.result(); + byte[] sketchBytes = sketch.toByteArray(sketchSerializer); + target.writeInt(sketchBytes.length); + target.write(sketchBytes); + } else { + throw new IllegalArgumentException("Unsupported data statistics type: " + statisticsType); + } + } + + @Override + public DataStatistics deserialize(DataInputView source) throws IOException { + StatisticsType statisticsType = statisticsTypeSerializer.deserialize(source); + if (statisticsType == StatisticsType.Map) { + Map mapStatistics = mapSerializer.deserialize(source); + return new MapDataStatistics(mapStatistics); + } else if (statisticsType == StatisticsType.Sketch) { + int numBytes = source.readInt(); + byte[] sketchBytes = new byte[numBytes]; + source.read(sketchBytes); + Memory sketchMemory = Memory.wrap(sketchBytes); + ReservoirItemsSketch sketch = + ReservoirItemsSketch.heapify(sketchMemory, sketchSerializer); + return new SketchDataStatistics(sketch); + } else { + throw new IllegalArgumentException("Unsupported data statistics type: " + statisticsType); + } + } + + @Override + public DataStatistics deserialize(DataStatistics reuse, DataInputView source) throws IOException { + // not much benefit to reuse + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof DataStatisticsSerializer)) { + return false; + } + + DataStatisticsSerializer other = (DataStatisticsSerializer) obj; + return Objects.equals(sortKeySerializer, other.sortKeySerializer); + } + + @Override + public int hashCode() { + return sortKeySerializer.hashCode(); + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new DataStatisticsSerializerSnapshot(this); + } + + public static class DataStatisticsSerializerSnapshot + extends CompositeTypeSerializerSnapshot { + private static final int CURRENT_VERSION = 1; + + /** Constructor for read instantiation. 
*/ + @SuppressWarnings({"unused", "checkstyle:RedundantModifier"}) + public DataStatisticsSerializerSnapshot() { + super(DataStatisticsSerializer.class); + } + + @SuppressWarnings("checkstyle:RedundantModifier") + public DataStatisticsSerializerSnapshot(DataStatisticsSerializer serializer) { + super(serializer); + } + + @Override + protected int getCurrentOuterSnapshotVersion() { + return CURRENT_VERSION; + } + + @Override + protected TypeSerializer[] getNestedSerializers(DataStatisticsSerializer outerSerializer) { + return new TypeSerializer[] {outerSerializer.sortKeySerializer}; + } + + @Override + protected DataStatisticsSerializer createOuterSerializerWithNestedSerializers( + TypeSerializer[] nestedSerializers) { + SortKeySerializer sortKeySerializer = (SortKeySerializer) nestedSerializers[0]; + return new DataStatisticsSerializer(sortKeySerializer); + } + } +} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java deleted file mode 100644 index 8716cb872d0e..000000000000 --- a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
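The copy/serialize/deserialize paths above all go through the same DataSketches byte-array round trip: toByteArray with an ArrayOfItemsSerDe, then heapify from wrapped Memory. A standalone round trip using the library's ArrayOfStringsSerDe in place of the project's SortKeySketchSerializer (the import path assumes DataSketches 4.x or later; older releases keep the class directly under org.apache.datasketches):

import org.apache.datasketches.common.ArrayOfStringsSerDe;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.sampling.ReservoirItemsSketch;

public class SketchRoundTripExample {
  public static void main(String[] args) {
    ArrayOfStringsSerDe serDe = new ArrayOfStringsSerDe();

    ReservoirItemsSketch<String> original = ReservoirItemsSketch.newInstance(128);
    for (int i = 0; i < 1_000; i++) {
      original.update("key-" + i);
    }

    byte[] bytes = original.toByteArray(serDe); // serialize
    ReservoirItemsSketch<String> restored =
        ReservoirItemsSketch.heapify(Memory.wrap(bytes), serDe); // deserialize

    // the restored sketch has seen the same number of items as the original
    System.out.println(restored.getN() == original.getN());
  }
}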
- */ -package org.apache.iceberg.flink.sink.shuffle; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.core.memory.DataInputDeserializer; -import org.apache.flink.core.memory.DataOutputSerializer; - -/** - * DataStatisticsUtil is the utility to serialize and deserialize {@link DataStatistics} and {@link - * AggregatedStatistics} - */ -class DataStatisticsUtil { - - private DataStatisticsUtil() {} - - static , S> byte[] serializeDataStatistics( - DataStatistics dataStatistics, - TypeSerializer> statisticsSerializer) { - DataOutputSerializer out = new DataOutputSerializer(64); - try { - statisticsSerializer.serialize(dataStatistics, out); - return out.getCopyOfBuffer(); - } catch (IOException e) { - throw new IllegalStateException("Fail to serialize data statistics", e); - } - } - - @SuppressWarnings("unchecked") - static , S> D deserializeDataStatistics( - byte[] bytes, TypeSerializer> statisticsSerializer) { - DataInputDeserializer input = new DataInputDeserializer(bytes, 0, bytes.length); - try { - return (D) statisticsSerializer.deserialize(input); - } catch (IOException e) { - throw new IllegalStateException("Fail to deserialize data statistics", e); - } - } - - static , S> byte[] serializeAggregatedStatistics( - AggregatedStatistics aggregatedStatistics, - TypeSerializer> statisticsSerializer) - throws IOException { - ByteArrayOutputStream bytes = new ByteArrayOutputStream(); - ObjectOutputStream out = new ObjectOutputStream(bytes); - - DataOutputSerializer outSerializer = new DataOutputSerializer(64); - out.writeLong(aggregatedStatistics.checkpointId()); - statisticsSerializer.serialize(aggregatedStatistics.dataStatistics(), outSerializer); - byte[] statisticsBytes = outSerializer.getCopyOfBuffer(); - out.writeInt(statisticsBytes.length); - out.write(statisticsBytes); - out.flush(); - - return bytes.toByteArray(); - } - - static , S> - AggregatedStatistics deserializeAggregatedStatistics( - byte[] bytes, TypeSerializer> statisticsSerializer) - throws IOException { - ByteArrayInputStream bytesIn = new ByteArrayInputStream(bytes); - ObjectInputStream in = new ObjectInputStream(bytesIn); - - long completedCheckpointId = in.readLong(); - int statisticsBytesLength = in.readInt(); - byte[] statisticsBytes = new byte[statisticsBytesLength]; - in.readFully(statisticsBytes); - DataInputDeserializer input = - new DataInputDeserializer(statisticsBytes, 0, statisticsBytesLength); - DataStatistics dataStatistics = statisticsSerializer.deserialize(input); - - return new AggregatedStatistics<>(completedCheckpointId, dataStatistics); - } -} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatistics.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatistics.java new file mode 100644 index 000000000000..50ec23e9f7a2 --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatistics.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
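The coordinator now calls StatisticsUtil.serializeCompletedStatistics / deserializeCompletedStatistics instead of the deleted DataStatisticsUtil above. StatisticsUtil itself is not in this diff; a hedged sketch of the likely byte-array helpers, following the same DataOutputSerializer / DataInputDeserializer pattern the deleted utility used:

import java.io.IOException;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputDeserializer;
import org.apache.flink.core.memory.DataOutputSerializer;

final class SerializationHelperExample {
  private SerializationHelperExample() {}

  static <T> byte[] toBytes(T value, TypeSerializer<T> serializer) {
    DataOutputSerializer out = new DataOutputSerializer(64);
    try {
      serializer.serialize(value, out);
      return out.getCopyOfBuffer();
    } catch (IOException e) {
      throw new IllegalStateException("Fail to serialize " + value, e);
    }
  }

  static <T> T fromBytes(byte[] bytes, TypeSerializer<T> serializer) {
    DataInputDeserializer in = new DataInputDeserializer(bytes, 0, bytes.length);
    try {
      return serializer.deserialize(in);
    } catch (IOException e) {
      throw new IllegalStateException("Fail to deserialize bytes", e);
    }
  }
}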
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Arrays; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; + +/** + * This is used by {@link RangePartitioner} for guiding range partitioning. This is what is sent to + * the operator subtasks. For sketch statistics, it only contains much smaller range bounds than the + * complete raw samples. + */ +class GlobalStatistics { + private final long checkpointId; + private final StatisticsType type; + private final MapAssignment mapAssignment; + private final SortKey[] rangeBounds; + + private transient Integer hashCode; + + GlobalStatistics( + long checkpointId, StatisticsType type, MapAssignment mapAssignment, SortKey[] rangeBounds) { + Preconditions.checkArgument( + (mapAssignment != null && rangeBounds == null) + || (mapAssignment == null && rangeBounds != null), + "Invalid key assignment or range bounds: both are non-null or null"); + this.checkpointId = checkpointId; + this.type = type; + this.mapAssignment = mapAssignment; + this.rangeBounds = rangeBounds; + } + + static GlobalStatistics fromMapAssignment(long checkpointId, MapAssignment mapAssignment) { + return new GlobalStatistics(checkpointId, StatisticsType.Map, mapAssignment, null); + } + + static GlobalStatistics fromRangeBounds(long checkpointId, SortKey[] rangeBounds) { + return new GlobalStatistics(checkpointId, StatisticsType.Sketch, null, rangeBounds); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("checkpointId", checkpointId) + .add("type", type) + .add("mapAssignment", mapAssignment) + .add("rangeBounds", rangeBounds) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof GlobalStatistics)) { + return false; + } + + GlobalStatistics other = (GlobalStatistics) o; + return Objects.equal(checkpointId, other.checkpointId) + && Objects.equal(type, other.type) + && Objects.equal(mapAssignment, other.mapAssignment()) + && Arrays.equals(rangeBounds, other.rangeBounds()); + } + + @Override + public int hashCode() { + // implemented caching because coordinator can call the hashCode many times. + // when subtasks request statistics refresh upon initialization for reconciliation purpose, + // hashCode is used to check if there is any difference btw coordinator and operator state. 
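On the consuming side, sketch-based GlobalStatistics carries only the range bounds, and a range partitioner typically routes each key by binary-searching those bounds. A hedged illustration of that lookup (the project's actual RangePartitioner may handle ties and empty bounds differently; Integer keys stand in for SortKey):

import java.util.Arrays;
import java.util.Comparator;

final class RangeBoundsLookupExample {
  private RangeBoundsLookupExample() {}

  static int assignPartition(Integer key, Integer[] rangeBounds, Comparator<Integer> comparator) {
    int pos = Arrays.binarySearch(rangeBounds, key, comparator);
    // exact match routes to that bound's partition; otherwise use the insertion point,
    // so keys above the last bound land in the final partition (rangeBounds.length)
    return pos >= 0 ? pos : -pos - 1;
  }
}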
+ if (hashCode == null) { + this.hashCode = Objects.hashCode(checkpointId, type, mapAssignment, rangeBounds); + } + + return hashCode; + } + + long checkpointId() { + return checkpointId; + } + + StatisticsType type() { + return type; + } + + MapAssignment mapAssignment() { + return mapAssignment; + } + + SortKey[] rangeBounds() { + return rangeBounds; + } +} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatisticsSerializer.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatisticsSerializer.java new file mode 100644 index 000000000000..dfb947a84a0c --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/GlobalStatisticsSerializer.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.EnumSerializer; +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; + +class GlobalStatisticsSerializer extends TypeSerializer { + private final TypeSerializer sortKeySerializer; + private final EnumSerializer statisticsTypeSerializer; + private final ListSerializer rangeBoundsSerializer; + private final ListSerializer intsSerializer; + private final ListSerializer longsSerializer; + + GlobalStatisticsSerializer(TypeSerializer sortKeySerializer) { + this.sortKeySerializer = sortKeySerializer; + this.statisticsTypeSerializer = new EnumSerializer<>(StatisticsType.class); + this.rangeBoundsSerializer = new ListSerializer<>(sortKeySerializer); + this.intsSerializer = new ListSerializer<>(IntSerializer.INSTANCE); + this.longsSerializer = new ListSerializer<>(LongSerializer.INSTANCE); + } + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public TypeSerializer duplicate() { + return new GlobalStatisticsSerializer(sortKeySerializer); + } + + @Override + public GlobalStatistics createInstance() { + return GlobalStatistics.fromRangeBounds(0L, new SortKey[0]); + } + + @Override + public GlobalStatistics 
copy(GlobalStatistics from) { + return new GlobalStatistics( + from.checkpointId(), from.type(), from.mapAssignment(), from.rangeBounds()); + } + + @Override + public GlobalStatistics copy(GlobalStatistics from, GlobalStatistics reuse) { + // no benefit of reuse + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(GlobalStatistics record, DataOutputView target) throws IOException { + target.writeLong(record.checkpointId()); + statisticsTypeSerializer.serialize(record.type(), target); + if (record.type() == StatisticsType.Map) { + MapAssignment mapAssignment = record.mapAssignment(); + target.writeInt(mapAssignment.numPartitions()); + target.writeInt(mapAssignment.keyAssignments().size()); + for (Map.Entry entry : mapAssignment.keyAssignments().entrySet()) { + sortKeySerializer.serialize(entry.getKey(), target); + KeyAssignment keyAssignment = entry.getValue(); + intsSerializer.serialize(keyAssignment.assignedSubtasks(), target); + longsSerializer.serialize(keyAssignment.subtaskWeightsWithCloseFileCost(), target); + target.writeLong(keyAssignment.closeFileCostWeight()); + } + } else { + rangeBoundsSerializer.serialize(Arrays.asList(record.rangeBounds()), target); + } + } + + @Override + public GlobalStatistics deserialize(DataInputView source) throws IOException { + long checkpointId = source.readLong(); + StatisticsType type = statisticsTypeSerializer.deserialize(source); + if (type == StatisticsType.Map) { + int numPartitions = source.readInt(); + int mapSize = source.readInt(); + Map keyAssignments = Maps.newHashMapWithExpectedSize(mapSize); + for (int i = 0; i < mapSize; ++i) { + SortKey sortKey = sortKeySerializer.deserialize(source); + List assignedSubtasks = intsSerializer.deserialize(source); + List subtaskWeightsWithCloseFileCost = longsSerializer.deserialize(source); + long closeFileCostWeight = source.readLong(); + keyAssignments.put( + sortKey, + new KeyAssignment( + assignedSubtasks, subtaskWeightsWithCloseFileCost, closeFileCostWeight)); + } + + return GlobalStatistics.fromMapAssignment( + checkpointId, new MapAssignment(numPartitions, keyAssignments)); + } else { + List sortKeys = rangeBoundsSerializer.deserialize(source); + SortKey[] rangeBounds = new SortKey[sortKeys.size()]; + return GlobalStatistics.fromRangeBounds(checkpointId, sortKeys.toArray(rangeBounds)); + } + } + + @Override + public GlobalStatistics deserialize(GlobalStatistics reuse, DataInputView source) + throws IOException { + // not much benefit to reuse + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + serialize(deserialize(source), target); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + GlobalStatisticsSerializer other = (GlobalStatisticsSerializer) obj; + return Objects.equals(sortKeySerializer, other.sortKeySerializer); + } + + @Override + public int hashCode() { + return sortKeySerializer.hashCode(); + } + + @Override + public TypeSerializerSnapshot snapshotConfiguration() { + return new GlobalStatisticsSerializerSnapshot(this); + } + + public static class GlobalStatisticsSerializerSnapshot + extends CompositeTypeSerializerSnapshot { + private static final int CURRENT_VERSION = 1; + + /** Constructor for read instantiation. 
*/ + @SuppressWarnings({"unused", "checkstyle:RedundantModifier"}) + public GlobalStatisticsSerializerSnapshot() { + super(GlobalStatisticsSerializer.class); + } + + @SuppressWarnings("checkstyle:RedundantModifier") + public GlobalStatisticsSerializerSnapshot(GlobalStatisticsSerializer serializer) { + super(serializer); + } + + @Override + protected int getCurrentOuterSnapshotVersion() { + return CURRENT_VERSION; + } + + @Override + protected TypeSerializer[] getNestedSerializers(GlobalStatisticsSerializer outerSerializer) { + return new TypeSerializer[] {outerSerializer.sortKeySerializer}; + } + + @Override + protected GlobalStatisticsSerializer createOuterSerializerWithNestedSerializers( + TypeSerializer[] nestedSerializers) { + SortKeySerializer sortKeySerializer = (SortKeySerializer) nestedSerializers[0]; + return new GlobalStatisticsSerializer(sortKeySerializer); + } + } +} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/KeyAssignment.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/KeyAssignment.java new file mode 100644 index 000000000000..a164d83ac3b0 --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/KeyAssignment.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; + +/** Subtask assignment for a key for Map statistics based */ +class KeyAssignment { + private final List assignedSubtasks; + private final List subtaskWeightsWithCloseFileCost; + private final long closeFileCostWeight; + private final long[] subtaskWeightsExcludingCloseCost; + private final long keyWeight; + private final long[] cumulativeWeights; + + /** + * @param assignedSubtasks assigned subtasks for this key. It could be a single subtask. It could + * also be multiple subtasks if the key has heavy weight that should be handled by multiple + * subtasks. + * @param subtaskWeightsWithCloseFileCost assigned weight for each subtask. E.g., if the keyWeight + * is 27 and the key is assigned to 3 subtasks, subtaskWeights could contain values as [10, + * 10, 7] for target weight of 10 per subtask. 
+ */ + KeyAssignment( + List assignedSubtasks, + List subtaskWeightsWithCloseFileCost, + long closeFileCostWeight) { + Preconditions.checkArgument( + assignedSubtasks != null && !assignedSubtasks.isEmpty(), + "Invalid assigned subtasks: null or empty"); + Preconditions.checkArgument( + subtaskWeightsWithCloseFileCost != null && !subtaskWeightsWithCloseFileCost.isEmpty(), + "Invalid assigned subtasks weights: null or empty"); + Preconditions.checkArgument( + assignedSubtasks.size() == subtaskWeightsWithCloseFileCost.size(), + "Invalid assignment: size mismatch (tasks length = %s, weights length = %s)", + assignedSubtasks.size(), + subtaskWeightsWithCloseFileCost.size()); + subtaskWeightsWithCloseFileCost.forEach( + weight -> + Preconditions.checkArgument( + weight > closeFileCostWeight, + "Invalid weight: should be larger than close file cost: weight = %s, close file cost = %s", + weight, + closeFileCostWeight)); + + this.assignedSubtasks = assignedSubtasks; + this.subtaskWeightsWithCloseFileCost = subtaskWeightsWithCloseFileCost; + this.closeFileCostWeight = closeFileCostWeight; + // Exclude the close file cost for key routing + this.subtaskWeightsExcludingCloseCost = + subtaskWeightsWithCloseFileCost.stream() + .mapToLong(weightWithCloseFileCost -> weightWithCloseFileCost - closeFileCostWeight) + .toArray(); + this.keyWeight = Arrays.stream(subtaskWeightsExcludingCloseCost).sum(); + this.cumulativeWeights = new long[subtaskWeightsExcludingCloseCost.length]; + long cumulativeWeight = 0; + for (int i = 0; i < subtaskWeightsExcludingCloseCost.length; ++i) { + cumulativeWeight += subtaskWeightsExcludingCloseCost[i]; + cumulativeWeights[i] = cumulativeWeight; + } + } + + List assignedSubtasks() { + return assignedSubtasks; + } + + List subtaskWeightsWithCloseFileCost() { + return subtaskWeightsWithCloseFileCost; + } + + long closeFileCostWeight() { + return closeFileCostWeight; + } + + long[] subtaskWeightsExcludingCloseCost() { + return subtaskWeightsExcludingCloseCost; + } + + /** @return subtask id */ + int select() { + if (assignedSubtasks.size() == 1) { + // only choice. no need to run random number generator. + return assignedSubtasks.get(0); + } else { + long randomNumber = ThreadLocalRandom.current().nextLong(keyWeight); + int index = Arrays.binarySearch(cumulativeWeights, randomNumber); + // choose the subtask where randomNumber < cumulativeWeights[pos]. + // this works regardless whether index is negative or not. + int position = Math.abs(index + 1); + Preconditions.checkState( + position < assignedSubtasks.size(), + "Invalid selected position: out of range. 
key weight = %s, random number = %s, cumulative weights array = %s", + keyWeight, + randomNumber, + cumulativeWeights); + return assignedSubtasks.get(position); + } + } + + @Override + public int hashCode() { + return Objects.hash(assignedSubtasks, subtaskWeightsWithCloseFileCost, closeFileCostWeight); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + KeyAssignment that = (KeyAssignment) o; + return Objects.equals(assignedSubtasks, that.assignedSubtasks) + && Objects.equals(subtaskWeightsWithCloseFileCost, that.subtaskWeightsWithCloseFileCost) + && closeFileCostWeight == that.closeFileCostWeight; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("assignedSubtasks", assignedSubtasks) + .add("subtaskWeightsWithCloseFileCost", subtaskWeightsWithCloseFileCost) + .add("closeFileCostWeight", closeFileCostWeight) + .toString(); + } +} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapAssignment.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapAssignment.java new file mode 100644 index 000000000000..0abb030c2279 --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapAssignment.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NavigableMap; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.util.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Key assignment to subtasks for Map statistics. 
*/ +class MapAssignment { + private static final Logger LOG = LoggerFactory.getLogger(MapAssignment.class); + + private final int numPartitions; + private final Map keyAssignments; + + MapAssignment(int numPartitions, Map keyAssignments) { + Preconditions.checkArgument(keyAssignments != null, "Invalid key assignments: null"); + this.numPartitions = numPartitions; + this.keyAssignments = keyAssignments; + } + + static MapAssignment fromKeyFrequency( + int numPartitions, + Map mapStatistics, + double closeFileCostWeightPercentage, + Comparator comparator) { + return new MapAssignment( + numPartitions, + assignment(numPartitions, mapStatistics, closeFileCostWeightPercentage, comparator)); + } + + @Override + public int hashCode() { + return Objects.hashCode(numPartitions, keyAssignments); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + MapAssignment that = (MapAssignment) o; + return numPartitions == that.numPartitions && keyAssignments.equals(that.keyAssignments); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("numPartitions", numPartitions) + .add("keyAssignments", keyAssignments) + .toString(); + } + + int numPartitions() { + return numPartitions; + } + + Map keyAssignments() { + return keyAssignments; + } + + /** + * @return assignment summary for every subtask. Key is subtaskId. Value pair is (weight assigned + * to the subtask, number of keys assigned to the subtask) + */ + Map> assignmentInfo() { + Map> assignmentInfo = Maps.newTreeMap(); + keyAssignments.forEach( + (key, keyAssignment) -> { + for (int i = 0; i < keyAssignment.assignedSubtasks().size(); ++i) { + int subtaskId = keyAssignment.assignedSubtasks().get(i); + long subtaskWeight = keyAssignment.subtaskWeightsExcludingCloseCost()[i]; + Pair oldValue = assignmentInfo.getOrDefault(subtaskId, Pair.of(0L, 0)); + assignmentInfo.put( + subtaskId, Pair.of(oldValue.first() + subtaskWeight, oldValue.second() + 1)); + } + }); + + return assignmentInfo; + } + + static Map assignment( + int numPartitions, + Map mapStatistics, + double closeFileCostWeightPercentage, + Comparator comparator) { + mapStatistics.forEach( + (key, value) -> + Preconditions.checkArgument( + value > 0, "Invalid statistics: weight is 0 for key %s", key)); + + long totalWeight = mapStatistics.values().stream().mapToLong(l -> l).sum(); + double targetWeightPerSubtask = ((double) totalWeight) / numPartitions; + long closeFileCostWeight = + (long) Math.ceil(targetWeightPerSubtask * closeFileCostWeightPercentage / 100); + + NavigableMap sortedStatsWithCloseFileCost = Maps.newTreeMap(comparator); + mapStatistics.forEach( + (k, v) -> { + int estimatedSplits = (int) Math.ceil(v / targetWeightPerSubtask); + long estimatedCloseFileCost = closeFileCostWeight * estimatedSplits; + sortedStatsWithCloseFileCost.put(k, v + estimatedCloseFileCost); + }); + + long totalWeightWithCloseFileCost = + sortedStatsWithCloseFileCost.values().stream().mapToLong(l -> l).sum(); + long targetWeightPerSubtaskWithCloseFileCost = + (long) Math.ceil(((double) totalWeightWithCloseFileCost) / numPartitions); + return buildAssignment( + numPartitions, + sortedStatsWithCloseFileCost, + targetWeightPerSubtaskWithCloseFileCost, + closeFileCostWeight); + } + + private static Map buildAssignment( + int numPartitions, + NavigableMap sortedStatistics, + long targetWeightPerSubtask, + long closeFileCostWeight) { + Map assignmentMap = + 
Maps.newHashMapWithExpectedSize(sortedStatistics.size()); + Iterator mapKeyIterator = sortedStatistics.keySet().iterator(); + int subtaskId = 0; + SortKey currentKey = null; + long keyRemainingWeight = 0L; + long subtaskRemainingWeight = targetWeightPerSubtask; + List assignedSubtasks = Lists.newArrayList(); + List subtaskWeights = Lists.newArrayList(); + while (mapKeyIterator.hasNext() || currentKey != null) { + // This should never happen because target weight is calculated using ceil function. + if (subtaskId >= numPartitions) { + LOG.error( + "Internal algorithm error: exhausted subtasks with unassigned keys left. number of partitions: {}, " + + "target weight per subtask: {}, close file cost in weight: {}, data statistics: {}", + numPartitions, + targetWeightPerSubtask, + closeFileCostWeight, + sortedStatistics); + throw new IllegalStateException( + "Internal algorithm error: exhausted subtasks with unassigned keys left"); + } + + if (currentKey == null) { + currentKey = mapKeyIterator.next(); + keyRemainingWeight = sortedStatistics.get(currentKey); + } + + assignedSubtasks.add(subtaskId); + if (keyRemainingWeight < subtaskRemainingWeight) { + // assign the remaining weight of the key to the current subtask + subtaskWeights.add(keyRemainingWeight); + subtaskRemainingWeight -= keyRemainingWeight; + keyRemainingWeight = 0L; + } else { + // filled up the current subtask + long assignedWeight = subtaskRemainingWeight; + keyRemainingWeight -= subtaskRemainingWeight; + + // If assigned weight is less than close file cost, pad it up with close file cost. + // This might cause the subtask assigned weight over the target weight. + // But it should be no more than one close file cost. Small skew is acceptable. + if (assignedWeight <= closeFileCostWeight) { + long paddingWeight = Math.min(keyRemainingWeight, closeFileCostWeight); + keyRemainingWeight -= paddingWeight; + assignedWeight += paddingWeight; + } + + subtaskWeights.add(assignedWeight); + // move on to the next subtask + subtaskId += 1; + subtaskRemainingWeight = targetWeightPerSubtask; + } + + Preconditions.checkState( + assignedSubtasks.size() == subtaskWeights.size(), + "List size mismatch: assigned subtasks = %s, subtask weights = %s", + assignedSubtasks, + subtaskWeights); + + // If the remaining key weight is smaller than the close file cost, simply skip the residual + // as it doesn't make sense to assign a weight smaller than close file cost to a new subtask. + // this might lead to some inaccuracy in weight calculation. E.g., assuming the key weight is + // 2 and close file cost is 2. key weight with close cost is 4. Let's assume the previous + // task has a weight of 3 available. So weight of 3 for this key is assigned to the task and + // the residual weight of 1 is dropped. Then the routing weight for this key is 1 (minus the + // close file cost), which is inaccurate as the true key weight should be 2. + // Again, this greedy algorithm is not intended to be perfect. Some small inaccuracy is + // expected and acceptable. Traffic distribution should still be balanced. 
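+      // Worked example with illustrative numbers (not taken from this change): numPartitions = 2,
+      // closeFileCostWeight = 0, and key weights A=10, B=5, C=5 processed in sorted order give a
+      // target weight of 10 per subtask. Key A exactly fills subtask 0, so it is assigned
+      // subtasks [0] with weights [10]; keys B and C are then packed into subtask 1 as [1]/[5]
+      // each, and the loop terminates with both subtasks at the target weight.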
+ if (keyRemainingWeight > 0 && keyRemainingWeight <= closeFileCostWeight) { + keyRemainingWeight = 0; + } + + if (keyRemainingWeight == 0) { + // finishing up the assignment for the current key + KeyAssignment keyAssignment = + new KeyAssignment(assignedSubtasks, subtaskWeights, closeFileCostWeight); + assignmentMap.put(currentKey, keyAssignment); + assignedSubtasks = Lists.newArrayList(); + subtaskWeights = Lists.newArrayList(); + currentKey = null; + } + } + + return assignmentMap; + } +} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java index 0b63e2721178..05b943f6046f 100644 --- a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java @@ -19,52 +19,70 @@ package org.apache.iceberg.flink.sink.shuffle; import java.util.Map; -import org.apache.flink.annotation.Internal; import org.apache.iceberg.SortKey; import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; +import org.apache.iceberg.relocated.com.google.common.base.Objects; import org.apache.iceberg.relocated.com.google.common.collect.Maps; /** MapDataStatistics uses map to count key frequency */ -@Internal -class MapDataStatistics implements DataStatistics> { - private final Map statistics; +class MapDataStatistics implements DataStatistics { + private final Map keyFrequency; MapDataStatistics() { - this.statistics = Maps.newHashMap(); + this.keyFrequency = Maps.newHashMap(); } - MapDataStatistics(Map statistics) { - this.statistics = statistics; + MapDataStatistics(Map keyFrequency) { + this.keyFrequency = keyFrequency; + } + + @Override + public StatisticsType type() { + return StatisticsType.Map; } @Override public boolean isEmpty() { - return statistics.isEmpty(); + return keyFrequency.isEmpty(); } @Override public void add(SortKey sortKey) { - if (statistics.containsKey(sortKey)) { - statistics.merge(sortKey, 1L, Long::sum); + if (keyFrequency.containsKey(sortKey)) { + keyFrequency.merge(sortKey, 1L, Long::sum); } else { // clone the sort key before adding to map because input sortKey object can be reused SortKey copiedKey = sortKey.copy(); - statistics.put(copiedKey, 1L); + keyFrequency.put(copiedKey, 1L); } } @Override - public void merge(MapDataStatistics otherStatistics) { - otherStatistics.statistics().forEach((key, count) -> statistics.merge(key, count, Long::sum)); + public Object result() { + return keyFrequency; } @Override - public Map statistics() { - return statistics; + public String toString() { + return MoreObjects.toStringHelper(this).add("map", keyFrequency).toString(); } @Override - public String toString() { - return MoreObjects.toStringHelper(this).add("statistics", statistics).toString(); + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof MapDataStatistics)) { + return false; + } + + MapDataStatistics other = (MapDataStatistics) o; + return Objects.equal(keyFrequency, other.keyFrequency); + } + + @Override + public int hashCode() { + return Objects.hashCode(keyFrequency); } } diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsSerializer.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsSerializer.java deleted file mode 100644 index b6cccd0566fc..000000000000 --- 
a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsSerializer.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.iceberg.flink.sink.shuffle; - -import java.io.IOException; -import java.util.Map; -import java.util.Objects; -import org.apache.flink.annotation.Internal; -import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; -import org.apache.flink.api.common.typeutils.base.LongSerializer; -import org.apache.flink.api.common.typeutils.base.MapSerializer; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.util.Preconditions; -import org.apache.iceberg.SortKey; -import org.apache.iceberg.relocated.com.google.common.collect.Maps; - -@Internal -class MapDataStatisticsSerializer - extends TypeSerializer>> { - private final MapSerializer mapSerializer; - - static MapDataStatisticsSerializer fromSortKeySerializer( - TypeSerializer sortKeySerializer) { - return new MapDataStatisticsSerializer( - new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE)); - } - - MapDataStatisticsSerializer(MapSerializer mapSerializer) { - this.mapSerializer = mapSerializer; - } - - @Override - public boolean isImmutableType() { - return false; - } - - @SuppressWarnings("ReferenceEquality") - @Override - public TypeSerializer>> duplicate() { - MapSerializer duplicateMapSerializer = - (MapSerializer) mapSerializer.duplicate(); - return (duplicateMapSerializer == mapSerializer) - ? 
this - : new MapDataStatisticsSerializer(duplicateMapSerializer); - } - - @Override - public MapDataStatistics createInstance() { - return new MapDataStatistics(); - } - - @Override - public MapDataStatistics copy(DataStatistics> obj) { - Preconditions.checkArgument( - obj instanceof MapDataStatistics, "Invalid data statistics type: " + obj.getClass()); - MapDataStatistics from = (MapDataStatistics) obj; - TypeSerializer keySerializer = mapSerializer.getKeySerializer(); - Map newMap = Maps.newHashMapWithExpectedSize(from.statistics().size()); - for (Map.Entry entry : from.statistics().entrySet()) { - SortKey newKey = keySerializer.copy(entry.getKey()); - // no need to copy value since it is just a Long - newMap.put(newKey, entry.getValue()); - } - - return new MapDataStatistics(newMap); - } - - @Override - public DataStatistics> copy( - DataStatistics> from, - DataStatistics> reuse) { - // not much benefit to reuse - return copy(from); - } - - @Override - public int getLength() { - return -1; - } - - @Override - public void serialize( - DataStatistics> obj, DataOutputView target) - throws IOException { - Preconditions.checkArgument( - obj instanceof MapDataStatistics, "Invalid data statistics type: " + obj.getClass()); - MapDataStatistics mapStatistics = (MapDataStatistics) obj; - mapSerializer.serialize(mapStatistics.statistics(), target); - } - - @Override - public DataStatistics> deserialize(DataInputView source) - throws IOException { - return new MapDataStatistics(mapSerializer.deserialize(source)); - } - - @Override - public DataStatistics> deserialize( - DataStatistics> reuse, DataInputView source) - throws IOException { - // not much benefit to reuse - return deserialize(source); - } - - @Override - public void copy(DataInputView source, DataOutputView target) throws IOException { - mapSerializer.copy(source, target); - } - - @Override - public boolean equals(Object obj) { - if (!(obj instanceof MapDataStatisticsSerializer)) { - return false; - } - - MapDataStatisticsSerializer other = (MapDataStatisticsSerializer) obj; - return Objects.equals(mapSerializer, other.mapSerializer); - } - - @Override - public int hashCode() { - return mapSerializer.hashCode(); - } - - @Override - public TypeSerializerSnapshot>> - snapshotConfiguration() { - return new MapDataStatisticsSerializerSnapshot(this); - } - - public static class MapDataStatisticsSerializerSnapshot - extends CompositeTypeSerializerSnapshot< - DataStatistics>, MapDataStatisticsSerializer> { - private static final int CURRENT_VERSION = 1; - - // constructors need to public. Otherwise, Flink state restore would complain - // "The class has no (implicit) public nullary constructor". 
- @SuppressWarnings("checkstyle:RedundantModifier") - public MapDataStatisticsSerializerSnapshot() { - super(MapDataStatisticsSerializer.class); - } - - @SuppressWarnings("checkstyle:RedundantModifier") - public MapDataStatisticsSerializerSnapshot(MapDataStatisticsSerializer serializer) { - super(serializer); - } - - @Override - protected int getCurrentOuterSnapshotVersion() { - return CURRENT_VERSION; - } - - @Override - protected TypeSerializer[] getNestedSerializers( - MapDataStatisticsSerializer outerSerializer) { - return new TypeSerializer[] {outerSerializer.mapSerializer}; - } - - @Override - protected MapDataStatisticsSerializer createOuterSerializerWithNestedSerializers( - TypeSerializer[] nestedSerializers) { - @SuppressWarnings("unchecked") - MapSerializer mapSerializer = - (MapSerializer) nestedSerializers[0]; - return new MapDataStatisticsSerializer(mapSerializer); - } - } -} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java index dde86b5b6047..f36a078c94e0 100644 --- a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapRangePartitioner.java @@ -18,29 +18,14 @@ */ package org.apache.iceberg.flink.sink.shuffle; -import java.util.Arrays; -import java.util.Comparator; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.NavigableMap; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import org.apache.flink.api.common.functions.Partitioner; import org.apache.flink.table.data.RowData; import org.apache.iceberg.Schema; import org.apache.iceberg.SortKey; import org.apache.iceberg.SortOrder; -import org.apache.iceberg.SortOrderComparators; -import org.apache.iceberg.StructLike; import org.apache.iceberg.flink.FlinkSchemaUtil; import org.apache.iceberg.flink.RowDataWrapper; -import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting; -import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; -import org.apache.iceberg.relocated.com.google.common.base.Preconditions; -import org.apache.iceberg.relocated.com.google.common.collect.Lists; -import org.apache.iceberg.relocated.com.google.common.collect.Maps; -import org.apache.iceberg.util.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,52 +46,28 @@ class MapRangePartitioner implements Partitioner { private final RowDataWrapper rowDataWrapper; private final SortKey sortKey; - private final Comparator comparator; - private final Map mapStatistics; - private final double closeFileCostInWeightPercentage; + private final MapAssignment mapAssignment; // Counter that tracks how many times a new key encountered // where there is no traffic statistics learned about it. 
private long newSortKeyCounter; private long lastNewSortKeyLogTimeMilli; - // lazily computed due to the need of numPartitions - private Map assignment; - private NavigableMap sortedStatsWithCloseFileCost; - - MapRangePartitioner( - Schema schema, - SortOrder sortOrder, - MapDataStatistics dataStatistics, - double closeFileCostInWeightPercentage) { - dataStatistics - .statistics() - .entrySet() - .forEach( - entry -> - Preconditions.checkArgument( - entry.getValue() > 0, - "Invalid statistics: weight is 0 for key %s", - entry.getKey())); - + MapRangePartitioner(Schema schema, SortOrder sortOrder, MapAssignment mapAssignment) { this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct()); this.sortKey = new SortKey(schema, sortOrder); - this.comparator = SortOrderComparators.forSchema(schema, sortOrder); - this.mapStatistics = dataStatistics.statistics(); - this.closeFileCostInWeightPercentage = closeFileCostInWeightPercentage; + this.mapAssignment = mapAssignment; this.newSortKeyCounter = 0; this.lastNewSortKeyLogTimeMilli = System.currentTimeMillis(); } @Override public int partition(RowData row, int numPartitions) { - // assignment table can only be built lazily when first referenced here, - // because number of partitions (downstream subtasks) is needed. - // the numPartitions is not available in the constructor. - Map assignmentMap = assignment(numPartitions); // reuse the sortKey and rowDataWrapper sortKey.wrap(rowDataWrapper.wrap(row)); - KeyAssignment keyAssignment = assignmentMap.get(sortKey); + KeyAssignment keyAssignment = mapAssignment.keyAssignments().get(sortKey); + + int partition; if (keyAssignment == null) { LOG.trace( "Encountered new sort key: {}. Fall back to round robin as statistics not learned yet.", @@ -117,271 +78,18 @@ public int partition(RowData row, int numPartitions) { newSortKeyCounter += 1; long now = System.currentTimeMillis(); if (now - lastNewSortKeyLogTimeMilli > TimeUnit.MINUTES.toMillis(1)) { - LOG.info("Encounter new sort keys in total {} times", newSortKeyCounter); + LOG.info( + "Encounter new sort keys {} times. 
Fall back to round robin as statistics not learned yet", + newSortKeyCounter); lastNewSortKeyLogTimeMilli = now; + newSortKeyCounter = 0; } - return (int) (newSortKeyCounter % numPartitions); + partition = (int) (newSortKeyCounter % numPartitions); + } else { + partition = keyAssignment.select(); } - return keyAssignment.select(); - } - - @VisibleForTesting - Map assignment(int numPartitions) { - if (assignment == null) { - long totalWeight = mapStatistics.values().stream().mapToLong(l -> l).sum(); - double targetWeightPerSubtask = ((double) totalWeight) / numPartitions; - long closeFileCostInWeight = - (long) Math.ceil(targetWeightPerSubtask * closeFileCostInWeightPercentage / 100); - - this.sortedStatsWithCloseFileCost = Maps.newTreeMap(comparator); - mapStatistics.forEach( - (k, v) -> { - int estimatedSplits = (int) Math.ceil(v / targetWeightPerSubtask); - long estimatedCloseFileCost = closeFileCostInWeight * estimatedSplits; - sortedStatsWithCloseFileCost.put(k, v + estimatedCloseFileCost); - }); - - long totalWeightWithCloseFileCost = - sortedStatsWithCloseFileCost.values().stream().mapToLong(l -> l).sum(); - long targetWeightPerSubtaskWithCloseFileCost = - (long) Math.ceil(((double) totalWeightWithCloseFileCost) / numPartitions); - this.assignment = - buildAssignment( - numPartitions, - sortedStatsWithCloseFileCost, - targetWeightPerSubtaskWithCloseFileCost, - closeFileCostInWeight); - } - - return assignment; - } - - @VisibleForTesting - Map mapStatistics() { - return mapStatistics; - } - - /** - * Returns assignment summary for every subtask. - * - * @return assignment summary for every subtask. Key is subtaskId. Value pair is (weight assigned - * to the subtask, number of keys assigned to the subtask) - */ - Map> assignmentInfo() { - Map> assignmentInfo = Maps.newTreeMap(); - assignment.forEach( - (key, keyAssignment) -> { - for (int i = 0; i < keyAssignment.assignedSubtasks.length; ++i) { - int subtaskId = keyAssignment.assignedSubtasks[i]; - long subtaskWeight = keyAssignment.subtaskWeightsExcludingCloseCost[i]; - Pair oldValue = assignmentInfo.getOrDefault(subtaskId, Pair.of(0L, 0)); - assignmentInfo.put( - subtaskId, Pair.of(oldValue.first() + subtaskWeight, oldValue.second() + 1)); - } - }); - - return assignmentInfo; - } - - private Map buildAssignment( - int numPartitions, - NavigableMap sortedStatistics, - long targetWeightPerSubtask, - long closeFileCostInWeight) { - Map assignmentMap = - Maps.newHashMapWithExpectedSize(sortedStatistics.size()); - Iterator mapKeyIterator = sortedStatistics.keySet().iterator(); - int subtaskId = 0; - SortKey currentKey = null; - long keyRemainingWeight = 0L; - long subtaskRemainingWeight = targetWeightPerSubtask; - List assignedSubtasks = Lists.newArrayList(); - List subtaskWeights = Lists.newArrayList(); - while (mapKeyIterator.hasNext() || currentKey != null) { - // This should never happen because target weight is calculated using ceil function. - if (subtaskId >= numPartitions) { - LOG.error( - "Internal algorithm error: exhausted subtasks with unassigned keys left. 
number of partitions: {}, " - + "target weight per subtask: {}, close file cost in weight: {}, data statistics: {}", - numPartitions, - targetWeightPerSubtask, - closeFileCostInWeight, - sortedStatistics); - throw new IllegalStateException( - "Internal algorithm error: exhausted subtasks with unassigned keys left"); - } - - if (currentKey == null) { - currentKey = mapKeyIterator.next(); - keyRemainingWeight = sortedStatistics.get(currentKey); - } - - assignedSubtasks.add(subtaskId); - if (keyRemainingWeight < subtaskRemainingWeight) { - // assign the remaining weight of the key to the current subtask - subtaskWeights.add(keyRemainingWeight); - subtaskRemainingWeight -= keyRemainingWeight; - keyRemainingWeight = 0L; - } else { - // filled up the current subtask - long assignedWeight = subtaskRemainingWeight; - keyRemainingWeight -= subtaskRemainingWeight; - - // If assigned weight is less than close file cost, pad it up with close file cost. - // This might cause the subtask assigned weight over the target weight. - // But it should be no more than one close file cost. Small skew is acceptable. - if (assignedWeight <= closeFileCostInWeight) { - long paddingWeight = Math.min(keyRemainingWeight, closeFileCostInWeight); - keyRemainingWeight -= paddingWeight; - assignedWeight += paddingWeight; - } - - subtaskWeights.add(assignedWeight); - // move on to the next subtask - subtaskId += 1; - subtaskRemainingWeight = targetWeightPerSubtask; - } - - Preconditions.checkState( - assignedSubtasks.size() == subtaskWeights.size(), - "List size mismatch: assigned subtasks = %s, subtask weights = %s", - assignedSubtasks, - subtaskWeights); - - // If the remaining key weight is smaller than the close file cost, simply skip the residual - // as it doesn't make sense to assign a weight smaller than close file cost to a new subtask. - // this might lead to some inaccuracy in weight calculation. E.g., assuming the key weight is - // 2 and close file cost is 2. key weight with close cost is 4. Let's assume the previous - // task has a weight of 3 available. So weight of 3 for this key is assigned to the task and - // the residual weight of 1 is dropped. Then the routing weight for this key is 1 (minus the - // close file cost), which is inaccurate as the true key weight should be 2. - // Again, this greedy algorithm is not intended to be perfect. Some small inaccuracy is - // expected and acceptable. Traffic distribution should still be balanced. - if (keyRemainingWeight > 0 && keyRemainingWeight <= closeFileCostInWeight) { - keyRemainingWeight = 0; - } - - if (keyRemainingWeight == 0) { - // finishing up the assignment for the current key - KeyAssignment keyAssignment = - new KeyAssignment(assignedSubtasks, subtaskWeights, closeFileCostInWeight); - assignmentMap.put(currentKey, keyAssignment); - assignedSubtasks.clear(); - subtaskWeights.clear(); - currentKey = null; - } - } - - return assignmentMap; - } - - /** Subtask assignment for a key */ - @VisibleForTesting - static class KeyAssignment { - private final int[] assignedSubtasks; - private final long[] subtaskWeightsExcludingCloseCost; - private final long keyWeight; - private final long[] cumulativeWeights; - - /** - * @param assignedSubtasks assigned subtasks for this key. It could be a single subtask. It - * could also be multiple subtasks if the key has heavy weight that should be handled by - * multiple subtasks. - * @param subtaskWeightsWithCloseFileCost assigned weight for each subtask. 
E.g., if the - * keyWeight is 27 and the key is assigned to 3 subtasks, subtaskWeights could contain - * values as [10, 10, 7] for target weight of 10 per subtask. - */ - KeyAssignment( - List assignedSubtasks, - List subtaskWeightsWithCloseFileCost, - long closeFileCostInWeight) { - Preconditions.checkArgument( - assignedSubtasks != null && !assignedSubtasks.isEmpty(), - "Invalid assigned subtasks: null or empty"); - Preconditions.checkArgument( - subtaskWeightsWithCloseFileCost != null && !subtaskWeightsWithCloseFileCost.isEmpty(), - "Invalid assigned subtasks weights: null or empty"); - Preconditions.checkArgument( - assignedSubtasks.size() == subtaskWeightsWithCloseFileCost.size(), - "Invalid assignment: size mismatch (tasks length = %s, weights length = %s)", - assignedSubtasks.size(), - subtaskWeightsWithCloseFileCost.size()); - subtaskWeightsWithCloseFileCost.forEach( - weight -> - Preconditions.checkArgument( - weight > closeFileCostInWeight, - "Invalid weight: should be larger than close file cost: weight = %s, close file cost = %s", - weight, - closeFileCostInWeight)); - - this.assignedSubtasks = assignedSubtasks.stream().mapToInt(i -> i).toArray(); - // Exclude the close file cost for key routing - this.subtaskWeightsExcludingCloseCost = - subtaskWeightsWithCloseFileCost.stream() - .mapToLong(weightWithCloseFileCost -> weightWithCloseFileCost - closeFileCostInWeight) - .toArray(); - this.keyWeight = Arrays.stream(subtaskWeightsExcludingCloseCost).sum(); - this.cumulativeWeights = new long[subtaskWeightsExcludingCloseCost.length]; - long cumulativeWeight = 0; - for (int i = 0; i < subtaskWeightsExcludingCloseCost.length; ++i) { - cumulativeWeight += subtaskWeightsExcludingCloseCost[i]; - cumulativeWeights[i] = cumulativeWeight; - } - } - - /** - * Select a subtask for the key. - * - * @return subtask id - */ - int select() { - if (assignedSubtasks.length == 1) { - // only choice. no need to run random number generator. - return assignedSubtasks[0]; - } else { - long randomNumber = ThreadLocalRandom.current().nextLong(keyWeight); - int index = Arrays.binarySearch(cumulativeWeights, randomNumber); - // choose the subtask where randomNumber < cumulativeWeights[pos]. - // this works regardless whether index is negative or not. - int position = Math.abs(index + 1); - Preconditions.checkState( - position < assignedSubtasks.length, - "Invalid selected position: out of range. 
key weight = %s, random number = %s, cumulative weights array = %s", - keyWeight, - randomNumber, - cumulativeWeights); - return assignedSubtasks[position]; - } - } - - @Override - public int hashCode() { - return 31 * Arrays.hashCode(assignedSubtasks) - + Arrays.hashCode(subtaskWeightsExcludingCloseCost); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - - if (o == null || getClass() != o.getClass()) { - return false; - } - - KeyAssignment that = (KeyAssignment) o; - return Arrays.equals(assignedSubtasks, that.assignedSubtasks) - && Arrays.equals(subtaskWeightsExcludingCloseCost, that.subtaskWeightsExcludingCloseCost); - } - - @Override - public String toString() { - return MoreObjects.toStringHelper(this) - .add("assignedSubtasks", assignedSubtasks) - .add("subtaskWeightsExcludingCloseCost", subtaskWeightsExcludingCloseCost) - .toString(); - } + return RangePartitioner.adjustPartitionWithRescale( + partition, mapAssignment.numPartitions(), numPartitions); } } diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/RangePartitioner.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/RangePartitioner.java new file mode 100644 index 000000000000..83a9461233d2 --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/RangePartitioner.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicLong;
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortOrder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Range partitioner that wraps {@link StatisticsOrRecord}. It delegates to a map or sketch based
+ * partitioner once {@link GlobalStatistics} have been received, and falls back to round-robin
+ * before that.
+ */
+@Internal
+public class RangePartitioner implements Partitioner<StatisticsOrRecord> {
+  private static final Logger LOG = LoggerFactory.getLogger(RangePartitioner.class);
+
+  private final Schema schema;
+  private final SortOrder sortOrder;
+
+  private transient AtomicLong roundRobinCounter;
+  private transient Partitioner<RowData> delegatePartitioner;
+
+  public RangePartitioner(Schema schema, SortOrder sortOrder) {
+    this.schema = schema;
+    this.sortOrder = sortOrder;
+  }
+
+  @Override
+  public int partition(StatisticsOrRecord wrapper, int numPartitions) {
+    if (wrapper.hasStatistics()) {
+      this.delegatePartitioner = delegatePartitioner(wrapper.statistics());
+      return (int) (roundRobinCounter(numPartitions).getAndIncrement() % numPartitions);
+    } else {
+      if (delegatePartitioner != null) {
+        return delegatePartitioner.partition(wrapper.record(), numPartitions);
+      } else {
+        int partition = (int) (roundRobinCounter(numPartitions).getAndIncrement() % numPartitions);
+        LOG.trace("Statistics not available. Round robin to partition {}", partition);
+        return partition;
+      }
+    }
+  }
+
+  private AtomicLong roundRobinCounter(int numPartitions) {
+    if (roundRobinCounter == null) {
+      // randomize the starting point to avoid synchronization across subtasks
+      this.roundRobinCounter = new AtomicLong(new Random().nextInt(numPartitions));
+    }
+
+    return roundRobinCounter;
+  }
+
+  private Partitioner<RowData> delegatePartitioner(GlobalStatistics statistics) {
+    if (statistics.type() == StatisticsType.Map) {
+      return new MapRangePartitioner(schema, sortOrder, statistics.mapAssignment());
+    } else if (statistics.type() == StatisticsType.Sketch) {
+      return new SketchRangePartitioner(schema, sortOrder, statistics.rangeBounds());
+    } else {
+      throw new IllegalArgumentException(
+          String.format("Invalid statistics type: %s. Should be Map or Sketch", statistics.type()));
+    }
+  }
+
+  /**
+   * Utility method that handles rescale (write parallelism / numPartitions change).
+   *
+   * @param partition partition calculated based on the existing statistics
+   * @param numPartitionsStatsCalculation number of partitions that the assignment was calculated
+   *     for
+   * @param numPartitions current number of partitions
+   * @return adjusted partition if necessary.
+   */
+  static int adjustPartitionWithRescale(
+      int partition, int numPartitionsStatsCalculation, int numPartitions) {
+    if (numPartitionsStatsCalculation <= numPartitions) {
+      // no rescale or scale-up case.
+      // new subtasks are ignored and not assigned any keys, which is sub-optimal and only
+      // transient.
+      // when rescale is detected, operator requests new statistics from coordinator upon
+      // initialization.
+      return partition;
+    } else {
+      // scale-down case.
+      // Use the mod % operation to distribute the out-of-range partitions.
+      // This can cause skew among subtasks, but the behavior is still better than
+      // discarding the statistics and falling back to round-robin (no clustering).
+      // Again, this is transient and stats refresh is requested when rescale is detected.
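+      // For example (illustrative numbers): if the assignment was computed for 4 partitions but
+      // the writer now runs with 2, computed partitions 0, 1, 2, 3 map to 0, 1, 0, 1 respectively.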
+ return partition % numPartitions; + } + } +} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/RequestGlobalStatisticsEvent.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/RequestGlobalStatisticsEvent.java new file mode 100644 index 000000000000..ce17e1964392 --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/RequestGlobalStatisticsEvent.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import org.apache.flink.runtime.operators.coordination.OperatorEvent; + +class RequestGlobalStatisticsEvent implements OperatorEvent { + private final Integer signature; + + RequestGlobalStatisticsEvent() { + this.signature = null; + } + + /** @param signature hashCode of the subtask's existing global statistics */ + RequestGlobalStatisticsEvent(int signature) { + this.signature = signature; + } + + Integer signature() { + return signature; + } +} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchDataStatistics.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchDataStatistics.java new file mode 100644 index 000000000000..35bbb27baf16 --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchDataStatistics.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.Arrays;
+import org.apache.datasketches.sampling.ReservoirItemsSketch;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
+import org.apache.iceberg.relocated.com.google.common.base.Objects;
+
+/** SketchDataStatistics samples sort keys with a reservoir sketch to estimate range bounds */
+class SketchDataStatistics implements DataStatistics {
+
+  private final ReservoirItemsSketch<SortKey> sketch;
+
+  SketchDataStatistics(int reservoirSize) {
+    this.sketch = ReservoirItemsSketch.newInstance(reservoirSize);
+  }
+
+  SketchDataStatistics(ReservoirItemsSketch<SortKey> sketchStats) {
+    this.sketch = sketchStats;
+  }
+
+  @Override
+  public StatisticsType type() {
+    return StatisticsType.Sketch;
+  }
+
+  @Override
+  public boolean isEmpty() {
+    return sketch.getNumSamples() == 0;
+  }
+
+  @Override
+  public void add(SortKey sortKey) {
+    // clone the sort key first because input sortKey object can be reused
+    SortKey copiedKey = sortKey.copy();
+    sketch.update(copiedKey);
+  }
+
+  @Override
+  public Object result() {
+    return sketch;
+  }
+
+  @Override
+  public String toString() {
+    return MoreObjects.toStringHelper(this).add("sketch", sketch).toString();
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+
+    if (!(o instanceof SketchDataStatistics)) {
+      return false;
+    }
+
+    ReservoirItemsSketch<SortKey> otherSketch = ((SketchDataStatistics) o).sketch;
+    return Objects.equal(sketch.getK(), otherSketch.getK())
+        && Objects.equal(sketch.getN(), otherSketch.getN())
+        && Arrays.deepEquals(sketch.getSamples(), otherSketch.getSamples());
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(sketch.getK(), sketch.getN(), sketch.getSamples());
+  }
+}
diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchRangePartitioner.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchRangePartitioner.java
new file mode 100644
index 000000000000..af78271ea5dc
--- /dev/null
+++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchRangePartitioner.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Arrays; +import java.util.Comparator; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.table.data.RowData; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.SortOrderComparators; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.flink.FlinkSchemaUtil; +import org.apache.iceberg.flink.RowDataWrapper; + +class SketchRangePartitioner implements Partitioner { + private final SortKey sortKey; + private final Comparator comparator; + private final SortKey[] rangeBounds; + private final RowDataWrapper rowDataWrapper; + + SketchRangePartitioner(Schema schema, SortOrder sortOrder, SortKey[] rangeBounds) { + this.sortKey = new SortKey(schema, sortOrder); + this.comparator = SortOrderComparators.forSchema(schema, sortOrder); + this.rangeBounds = rangeBounds; + this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct()); + } + + @Override + public int partition(RowData row, int numPartitions) { + // reuse the sortKey and rowDataWrapper + sortKey.wrap(rowDataWrapper.wrap(row)); + int partition = Arrays.binarySearch(rangeBounds, sortKey, comparator); + + // binarySearch either returns the match location or -[insertion point]-1 + if (partition < 0) { + partition = -partition - 1; + } + + if (partition > rangeBounds.length) { + partition = rangeBounds.length; + } + + return RangePartitioner.adjustPartitionWithRescale( + partition, rangeBounds.length + 1, numPartitions); + } +} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchUtil.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchUtil.java new file mode 100644 index 000000000000..a58310611e8d --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SketchUtil.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.Map; +import java.util.function.Consumer; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.StructLike; + +class SketchUtil { + static final int COORDINATOR_MIN_RESERVOIR_SIZE = 10_000; + static final int COORDINATOR_MAX_RESERVOIR_SIZE = 1_000_000; + static final int COORDINATOR_TARGET_PARTITIONS_MULTIPLIER = 100; + static final int OPERATOR_OVER_SAMPLE_RATIO = 10; + + // switch the statistics tracking from map to sketch if the cardinality of the sort key is over + // this threshold. It is hardcoded for now, we can revisit in the future if config is needed. 
+ static final int OPERATOR_SKETCH_SWITCH_THRESHOLD = 10_000; + static final int COORDINATOR_SKETCH_SWITCH_THRESHOLD = 100_000; + + private SketchUtil() {} + + /** + * The larger the reservoir size, the more accurate for range bounds calculation and the more + * balanced range distribution. + * + *

    Here are the heuristic rules + *

  • Target size: numPartitions x 100, which achieves good accuracy and makes it easier to calculate the + * range bounds + *
  • Min is 10K to achieve good accuracy while the memory footprint is still relatively small + *
  • Max is 1M to cap the memory footprint on coordinator + * + * @param numPartitions number of range partitions which equals to downstream operator parallelism + * @return reservoir size + */ + static int determineCoordinatorReservoirSize(int numPartitions) { + int reservoirSize = numPartitions * COORDINATOR_TARGET_PARTITIONS_MULTIPLIER; + + if (reservoirSize < COORDINATOR_MIN_RESERVOIR_SIZE) { + // adjust it up and still make reservoirSize divisible by numPartitions + int remainder = COORDINATOR_MIN_RESERVOIR_SIZE % numPartitions; + reservoirSize = COORDINATOR_MIN_RESERVOIR_SIZE + (numPartitions - remainder); + } else if (reservoirSize > COORDINATOR_MAX_RESERVOIR_SIZE) { + // adjust it down and still make reservoirSize divisible by numPartitions + int remainder = COORDINATOR_MAX_RESERVOIR_SIZE % numPartitions; + reservoirSize = COORDINATOR_MAX_RESERVOIR_SIZE - remainder; + } + + return reservoirSize; + } + + /** + * Determine the sampling reservoir size where operator subtasks collect data statistics. + * + *
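+ * <p>For example, taking numPartitions = 3 and operatorParallelism = 2 (illustrative values): the
+ * coordinator target is 3 x 100 = 300, below the 10K minimum, so it is adjusted up to 10,002 to stay
+ * divisible by 3; each operator then samples ceil(10,002 x 10 / 2) = 50,010 keys.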

    Here are the heuristic rules + *

  • Target size is "coordinator reservoir size * over sampling ratio (10) / operator + * parallelism" + *
  • Min is 1K to achieve good accuracy while the memory footprint is still relatively small + *
  • Max is 100K to cap the memory footprint on coordinator + * + * @param numPartitions number of range partitions which equals to downstream operator parallelism + * @param operatorParallelism data statistics operator parallelism + * @return reservoir size + */ + static int determineOperatorReservoirSize(int operatorParallelism, int numPartitions) { + int coordinatorReservoirSize = determineCoordinatorReservoirSize(numPartitions); + int totalOperatorSamples = coordinatorReservoirSize * OPERATOR_OVER_SAMPLE_RATIO; + return (int) Math.ceil((double) totalOperatorSamples / operatorParallelism); + } + + /** + * To understand how range bounds are used in range partitioning, here is an example for human + * ages with 4 partitions: [15, 32, 60]. The 4 ranges would be + * + *
      + *
    • age <= 15 + *
    • age > 15 && age <= 32 + *
    • age > 32 && age <= 60 + *
    • age > 60 + *
    + * + *
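+ * <p>Bound selection strides over the sorted samples: for example, with 12 sorted samples and 4
+ * partitions the step is ceil(12 / 4) = 3, so the samples at 0-based positions 2, 5 and 8 become the
+ * 3 range bounds; duplicate values are skipped by linearly probing to the next distinct sample.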

    Assumption is that a single key is not dominant enough to span multiple subtasks. + * + * @param numPartitions number of partitions which maps to downstream operator parallelism + * @param samples sampled keys + * @return array of range partition bounds. It should be a sorted list (ascending). Number of + * items should be {@code numPartitions - 1}. if numPartitions is 1, return an empty list + */ + static SortKey[] rangeBounds( + int numPartitions, Comparator comparator, SortKey[] samples) { + // sort the keys first + Arrays.sort(samples, comparator); + int numCandidates = numPartitions - 1; + SortKey[] candidates = new SortKey[numCandidates]; + int step = (int) Math.ceil((double) samples.length / numPartitions); + int position = step - 1; + int numChosen = 0; + while (position < samples.length && numChosen < numCandidates) { + SortKey candidate = samples[position]; + // skip duplicate values + if (numChosen > 0 && candidate.equals(candidates[numChosen - 1])) { + // linear probe for the next distinct value + position += 1; + } else { + candidates[numChosen] = candidate; + position += step; + numChosen += 1; + } + } + + return candidates; + } + + /** This can be a bit expensive since it is quadratic. */ + static void convertMapToSketch( + Map taskMapStats, Consumer sketchConsumer) { + taskMapStats.forEach( + (sortKey, count) -> { + for (int i = 0; i < count; ++i) { + sketchConsumer.accept(sortKey); + } + }); + } +} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySerializer.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySerializer.java index d03409f2a430..4ddc5a32d6bf 100644 --- a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySerializer.java +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySerializer.java @@ -276,13 +276,12 @@ public static class SortKeySerializerSnapshot implements TypeSerializerSnapshot< private Schema schema; private SortOrder sortOrder; - @SuppressWarnings({"checkstyle:RedundantModifier", "WeakerAccess"}) + /** Constructor for read instantiation. */ + @SuppressWarnings({"unused", "checkstyle:RedundantModifier"}) public SortKeySerializerSnapshot() { // this constructor is used when restoring from a checkpoint. } - // constructors need to public. Otherwise, Flink state restore would complain - // "The class has no (implicit) public nullary constructor". 
@SuppressWarnings("checkstyle:RedundantModifier") public SortKeySerializerSnapshot(Schema schema, SortOrder sortOrder) { this.schema = schema; @@ -320,8 +319,12 @@ public TypeSerializerSchemaCompatibility resolveSchemaCompatibility( return TypeSerializerSchemaCompatibility.incompatible(); } - SortKeySerializer newAvroSerializer = (SortKeySerializer) newSerializer; - return resolveSchemaCompatibility(newAvroSerializer.schema, schema); + SortKeySerializer sortKeySerializer = (SortKeySerializer) newSerializer; + if (!sortOrder.sameOrder(sortKeySerializer.sortOrder)) { + return TypeSerializerSchemaCompatibility.incompatible(); + } + + return resolveSchemaCompatibility(sortKeySerializer.schema, schema); } @Override diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySketchSerializer.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySketchSerializer.java new file mode 100644 index 000000000000..d6c23f035015 --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySketchSerializer.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.IOException; +import java.io.Serializable; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.List; +import org.apache.datasketches.common.ArrayOfItemsSerDe; +import org.apache.datasketches.common.ArrayOfStringsSerDe; +import org.apache.datasketches.common.ByteArrayUtil; +import org.apache.datasketches.common.Util; +import org.apache.datasketches.memory.Memory; +import org.apache.datasketches.sampling.ReservoirItemsSketch; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; +import org.apache.flink.core.memory.DataInputDeserializer; +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; + +/** + * Only way to implement {@link ReservoirItemsSketch} serializer is to extend from {@link + * ArrayOfItemsSerDe}, as deserialization uses a private constructor from ReservoirItemsSketch. 
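+ * Each sort key is written as a 4-byte little-endian length prefix followed by the Flink-serialized
+ * key bytes (see serializeToByteArray below).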
The + * implementation is modeled after {@link ArrayOfStringsSerDe} + */ +class SortKeySketchSerializer extends ArrayOfItemsSerDe implements Serializable { + private static final int DEFAULT_SORT_KEY_SIZE = 128; + + private final TypeSerializer itemSerializer; + private final ListSerializer listSerializer; + private final DataInputDeserializer input; + + SortKeySketchSerializer(TypeSerializer itemSerializer) { + this.itemSerializer = itemSerializer; + this.listSerializer = new ListSerializer<>(itemSerializer); + this.input = new DataInputDeserializer(); + } + + @Override + public byte[] serializeToByteArray(SortKey item) { + try { + DataOutputSerializer output = new DataOutputSerializer(DEFAULT_SORT_KEY_SIZE); + itemSerializer.serialize(item, output); + byte[] itemBytes = output.getSharedBuffer(); + int numBytes = output.length(); + byte[] out = new byte[numBytes + Integer.BYTES]; + ByteArrayUtil.copyBytes(itemBytes, 0, out, 4, numBytes); + ByteArrayUtil.putIntLE(out, 0, numBytes); + return out; + } catch (IOException e) { + throw new UncheckedIOException("Failed to serialize sort key", e); + } + } + + @Override + public byte[] serializeToByteArray(SortKey[] items) { + try { + DataOutputSerializer output = new DataOutputSerializer(DEFAULT_SORT_KEY_SIZE * items.length); + listSerializer.serialize(Arrays.asList(items), output); + byte[] itemsBytes = output.getSharedBuffer(); + int numBytes = output.length(); + byte[] out = new byte[Integer.BYTES + numBytes]; + ByteArrayUtil.putIntLE(out, 0, numBytes); + System.arraycopy(itemsBytes, 0, out, Integer.BYTES, numBytes); + return out; + } catch (IOException e) { + throw new UncheckedIOException("Failed to serialize sort key", e); + } + } + + @Override + public SortKey[] deserializeFromMemory(Memory mem, long startingOffset, int numItems) { + Preconditions.checkArgument(mem != null, "Invalid input memory: null"); + if (numItems <= 0) { + return new SortKey[0]; + } + + long offset = startingOffset; + Util.checkBounds(offset, Integer.BYTES, mem.getCapacity()); + int numBytes = mem.getInt(offset); + offset += Integer.BYTES; + + Util.checkBounds(offset, numBytes, mem.getCapacity()); + byte[] sortKeyBytes = new byte[numBytes]; + mem.getByteArray(offset, sortKeyBytes, 0, numBytes); + input.setBuffer(sortKeyBytes); + + try { + List sortKeys = listSerializer.deserialize(input); + SortKey[] array = new SortKey[numItems]; + sortKeys.toArray(array); + input.releaseArrays(); + return array; + } catch (IOException e) { + throw new UncheckedIOException("Failed to deserialize sort key sketch", e); + } + } + + @Override + public int sizeOf(SortKey item) { + return serializeToByteArray(item).length; + } + + @Override + public int sizeOf(Memory mem, long offset, int numItems) { + Preconditions.checkArgument(mem != null, "Invalid input memory: null"); + if (numItems <= 0) { + return 0; + } + + Util.checkBounds(offset, Integer.BYTES, mem.getCapacity()); + int numBytes = mem.getInt(offset); + return Integer.BYTES + numBytes; + } + + @Override + public String toString(SortKey item) { + return item.toString(); + } + + @Override + public Class getClassOfT() { + return SortKey.class; + } +} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsEvent.java similarity index 58% rename from flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java rename to 
flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsEvent.java index 852d2157b8cb..f6fcdb8b16ef 100644 --- a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsEvent.java +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsEvent.java @@ -27,24 +27,39 @@ * statistics in bytes */ @Internal -class DataStatisticsEvent, S> implements OperatorEvent { +class StatisticsEvent implements OperatorEvent { private static final long serialVersionUID = 1L; private final long checkpointId; private final byte[] statisticsBytes; + private final boolean applyImmediately; - private DataStatisticsEvent(long checkpointId, byte[] statisticsBytes) { + private StatisticsEvent(long checkpointId, byte[] statisticsBytes, boolean applyImmediately) { this.checkpointId = checkpointId; this.statisticsBytes = statisticsBytes; + this.applyImmediately = applyImmediately; } - static , S> DataStatisticsEvent create( + static StatisticsEvent createTaskStatisticsEvent( long checkpointId, - DataStatistics dataStatistics, - TypeSerializer> statisticsSerializer) { - return new DataStatisticsEvent<>( + DataStatistics statistics, + TypeSerializer statisticsSerializer) { + // applyImmediately is really only relevant for coordinator to operator event. + // task reported statistics is always merged immediately by the coordinator. + return new StatisticsEvent( checkpointId, - DataStatisticsUtil.serializeDataStatistics(dataStatistics, statisticsSerializer)); + StatisticsUtil.serializeDataStatistics(statistics, statisticsSerializer), + true); + } + + static StatisticsEvent createGlobalStatisticsEvent( + GlobalStatistics statistics, + TypeSerializer statisticsSerializer, + boolean applyImmediately) { + return new StatisticsEvent( + statistics.checkpointId(), + StatisticsUtil.serializeGlobalStatistics(statistics, statisticsSerializer), + applyImmediately); } long checkpointId() { @@ -54,4 +69,8 @@ long checkpointId() { byte[] statisticsBytes() { return statisticsBytes; } + + boolean applyImmediately() { + return applyImmediately; + } } diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsOrRecord.java similarity index 66% rename from flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java rename to flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsOrRecord.java index 889e85112e16..bc28df2b0e22 100644 --- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecord.java +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsOrRecord.java @@ -19,6 +19,7 @@ package org.apache.iceberg.flink.sink.shuffle; import java.io.Serializable; +import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.table.data.RowData; import org.apache.iceberg.relocated.com.google.common.base.MoreObjects; @@ -34,68 +35,66 @@ * After shuffling, a filter and mapper are required to filter out the data distribution weight, * unwrap the object and extract the original record type T. 
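 * <p>For example, a downstream consumer could recover the plain {@link RowData} records with
 * {@code stream.filter(StatisticsOrRecord::hasRecord).map(StatisticsOrRecord::record)}.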
*/ -class DataStatisticsOrRecord, S> implements Serializable { +@Internal +public class StatisticsOrRecord implements Serializable { private static final long serialVersionUID = 1L; - private DataStatistics statistics; + private GlobalStatistics statistics; private RowData record; - private DataStatisticsOrRecord(DataStatistics statistics, RowData record) { + private StatisticsOrRecord(GlobalStatistics statistics, RowData record) { Preconditions.checkArgument( record != null ^ statistics != null, "DataStatistics or record, not neither or both"); this.statistics = statistics; this.record = record; } - static , S> DataStatisticsOrRecord fromRecord( - RowData record) { - return new DataStatisticsOrRecord<>(null, record); + static StatisticsOrRecord fromRecord(RowData record) { + return new StatisticsOrRecord(null, record); } - static , S> DataStatisticsOrRecord fromDataStatistics( - DataStatistics statistics) { - return new DataStatisticsOrRecord<>(statistics, null); + static StatisticsOrRecord fromStatistics(GlobalStatistics statistics) { + return new StatisticsOrRecord(statistics, null); } - static , S> DataStatisticsOrRecord reuseRecord( - DataStatisticsOrRecord reuse, TypeSerializer recordSerializer) { + static StatisticsOrRecord reuseRecord( + StatisticsOrRecord reuse, TypeSerializer recordSerializer) { if (reuse.hasRecord()) { return reuse; } else { // not reusable - return DataStatisticsOrRecord.fromRecord(recordSerializer.createInstance()); + return StatisticsOrRecord.fromRecord(recordSerializer.createInstance()); } } - static , S> DataStatisticsOrRecord reuseStatistics( - DataStatisticsOrRecord reuse, - TypeSerializer> statisticsSerializer) { - if (reuse.hasDataStatistics()) { + static StatisticsOrRecord reuseStatistics( + StatisticsOrRecord reuse, TypeSerializer statisticsSerializer) { + if (reuse.hasStatistics()) { return reuse; } else { // not reusable - return DataStatisticsOrRecord.fromDataStatistics(statisticsSerializer.createInstance()); + return StatisticsOrRecord.fromStatistics(statisticsSerializer.createInstance()); } } - boolean hasDataStatistics() { + boolean hasStatistics() { return statistics != null; } - boolean hasRecord() { + public boolean hasRecord() { return record != null; } - DataStatistics dataStatistics() { + GlobalStatistics statistics() { return statistics; } - void dataStatistics(DataStatistics newStatistics) { + void statistics(GlobalStatistics newStatistics) { this.statistics = newStatistics; } - RowData record() { + public RowData record() { return record; } diff --git a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecordSerializer.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsOrRecordSerializer.java similarity index 53% rename from flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecordSerializer.java rename to flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsOrRecordSerializer.java index e9a6fa0cbfc5..6e403425938d 100644 --- a/flink/v1.17/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOrRecordSerializer.java +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsOrRecordSerializer.java @@ -29,13 +29,12 @@ import org.apache.flink.table.data.RowData; @Internal -class DataStatisticsOrRecordSerializer, S> - extends TypeSerializer> { - private final TypeSerializer> statisticsSerializer; +class StatisticsOrRecordSerializer extends TypeSerializer { + 
private final TypeSerializer statisticsSerializer; private final TypeSerializer recordSerializer; - DataStatisticsOrRecordSerializer( - TypeSerializer> statisticsSerializer, + StatisticsOrRecordSerializer( + TypeSerializer statisticsSerializer, TypeSerializer recordSerializer) { this.statisticsSerializer = statisticsSerializer; this.recordSerializer = recordSerializer; @@ -48,13 +47,13 @@ public boolean isImmutableType() { @SuppressWarnings("ReferenceEquality") @Override - public TypeSerializer> duplicate() { - TypeSerializer> duplicateStatisticsSerializer = + public TypeSerializer duplicate() { + TypeSerializer duplicateStatisticsSerializer = statisticsSerializer.duplicate(); TypeSerializer duplicateRowDataSerializer = recordSerializer.duplicate(); if ((statisticsSerializer != duplicateStatisticsSerializer) || (recordSerializer != duplicateRowDataSerializer)) { - return new DataStatisticsOrRecordSerializer<>( + return new StatisticsOrRecordSerializer( duplicateStatisticsSerializer, duplicateRowDataSerializer); } else { return this; @@ -62,34 +61,31 @@ public TypeSerializer> duplicate() { } @Override - public DataStatisticsOrRecord createInstance() { + public StatisticsOrRecord createInstance() { // arbitrarily always create RowData value instance - return DataStatisticsOrRecord.fromRecord(recordSerializer.createInstance()); + return StatisticsOrRecord.fromRecord(recordSerializer.createInstance()); } @Override - public DataStatisticsOrRecord copy(DataStatisticsOrRecord from) { + public StatisticsOrRecord copy(StatisticsOrRecord from) { if (from.hasRecord()) { - return DataStatisticsOrRecord.fromRecord(recordSerializer.copy(from.record())); + return StatisticsOrRecord.fromRecord(recordSerializer.copy(from.record())); } else { - return DataStatisticsOrRecord.fromDataStatistics( - statisticsSerializer.copy(from.dataStatistics())); + return StatisticsOrRecord.fromStatistics(statisticsSerializer.copy(from.statistics())); } } @Override - public DataStatisticsOrRecord copy( - DataStatisticsOrRecord from, DataStatisticsOrRecord reuse) { - DataStatisticsOrRecord to; + public StatisticsOrRecord copy(StatisticsOrRecord from, StatisticsOrRecord reuse) { + StatisticsOrRecord to; if (from.hasRecord()) { - to = DataStatisticsOrRecord.reuseRecord(reuse, recordSerializer); + to = StatisticsOrRecord.reuseRecord(reuse, recordSerializer); RowData record = recordSerializer.copy(from.record(), to.record()); to.record(record); } else { - to = DataStatisticsOrRecord.reuseStatistics(reuse, statisticsSerializer); - DataStatistics statistics = - statisticsSerializer.copy(from.dataStatistics(), to.dataStatistics()); - to.dataStatistics(statistics); + to = StatisticsOrRecord.reuseStatistics(reuse, statisticsSerializer); + GlobalStatistics statistics = statisticsSerializer.copy(from.statistics(), to.statistics()); + to.statistics(statistics); } return to; @@ -101,41 +97,40 @@ public int getLength() { } @Override - public void serialize(DataStatisticsOrRecord statisticsOrRecord, DataOutputView target) + public void serialize(StatisticsOrRecord statisticsOrRecord, DataOutputView target) throws IOException { if (statisticsOrRecord.hasRecord()) { target.writeBoolean(true); recordSerializer.serialize(statisticsOrRecord.record(), target); } else { target.writeBoolean(false); - statisticsSerializer.serialize(statisticsOrRecord.dataStatistics(), target); + statisticsSerializer.serialize(statisticsOrRecord.statistics(), target); } } @Override - public DataStatisticsOrRecord deserialize(DataInputView source) throws IOException { 
+ public StatisticsOrRecord deserialize(DataInputView source) throws IOException { boolean isRecord = source.readBoolean(); if (isRecord) { - return DataStatisticsOrRecord.fromRecord(recordSerializer.deserialize(source)); + return StatisticsOrRecord.fromRecord(recordSerializer.deserialize(source)); } else { - return DataStatisticsOrRecord.fromDataStatistics(statisticsSerializer.deserialize(source)); + return StatisticsOrRecord.fromStatistics(statisticsSerializer.deserialize(source)); } } @Override - public DataStatisticsOrRecord deserialize( - DataStatisticsOrRecord reuse, DataInputView source) throws IOException { - DataStatisticsOrRecord to; + public StatisticsOrRecord deserialize(StatisticsOrRecord reuse, DataInputView source) + throws IOException { + StatisticsOrRecord to; boolean isRecord = source.readBoolean(); if (isRecord) { - to = DataStatisticsOrRecord.reuseRecord(reuse, recordSerializer); + to = StatisticsOrRecord.reuseRecord(reuse, recordSerializer); RowData record = recordSerializer.deserialize(to.record(), source); to.record(record); } else { - to = DataStatisticsOrRecord.reuseStatistics(reuse, statisticsSerializer); - DataStatistics statistics = - statisticsSerializer.deserialize(to.dataStatistics(), source); - to.dataStatistics(statistics); + to = StatisticsOrRecord.reuseStatistics(reuse, statisticsSerializer); + GlobalStatistics statistics = statisticsSerializer.deserialize(to.statistics(), source); + to.statistics(statistics); } return to; @@ -154,12 +149,11 @@ public void copy(DataInputView source, DataOutputView target) throws IOException @Override public boolean equals(Object obj) { - if (!(obj instanceof DataStatisticsOrRecordSerializer)) { + if (!(obj instanceof StatisticsOrRecordSerializer)) { return false; } - @SuppressWarnings("unchecked") - DataStatisticsOrRecordSerializer other = (DataStatisticsOrRecordSerializer) obj; + StatisticsOrRecordSerializer other = (StatisticsOrRecordSerializer) obj; return Objects.equals(statisticsSerializer, other.statisticsSerializer) && Objects.equals(recordSerializer, other.recordSerializer); } @@ -170,25 +164,22 @@ public int hashCode() { } @Override - public TypeSerializerSnapshot> snapshotConfiguration() { - return new DataStatisticsOrRecordSerializerSnapshot<>(this); + public TypeSerializerSnapshot snapshotConfiguration() { + return new StatisticsOrRecordSerializerSnapshot(this); } - public static class DataStatisticsOrRecordSerializerSnapshot, S> - extends CompositeTypeSerializerSnapshot< - DataStatisticsOrRecord, DataStatisticsOrRecordSerializer> { + public static class StatisticsOrRecordSerializerSnapshot + extends CompositeTypeSerializerSnapshot { private static final int CURRENT_VERSION = 1; - // constructors need to public. Otherwise, Flink state restore would complain - // "The class has no (implicit) public nullary constructor". - @SuppressWarnings("checkstyle:RedundantModifier") - public DataStatisticsOrRecordSerializerSnapshot() { - super(DataStatisticsOrRecordSerializer.class); + /** Constructor for read instantiation. 
*/ + @SuppressWarnings({"unused", "checkstyle:RedundantModifier"}) + public StatisticsOrRecordSerializerSnapshot() { + super(StatisticsOrRecordSerializer.class); } @SuppressWarnings("checkstyle:RedundantModifier") - public DataStatisticsOrRecordSerializerSnapshot( - DataStatisticsOrRecordSerializer serializer) { + public StatisticsOrRecordSerializerSnapshot(StatisticsOrRecordSerializer serializer) { super(serializer); } @@ -200,7 +191,7 @@ protected int getCurrentOuterSnapshotVersion() { @Override protected TypeSerializer[] getNestedSerializers( - DataStatisticsOrRecordSerializer outerSerializer) { + StatisticsOrRecordSerializer outerSerializer) { return new TypeSerializer[] { outerSerializer.statisticsSerializer, outerSerializer.recordSerializer }; @@ -208,12 +199,12 @@ protected TypeSerializer[] getNestedSerializers( @SuppressWarnings("unchecked") @Override - protected DataStatisticsOrRecordSerializer createOuterSerializerWithNestedSerializers( + protected StatisticsOrRecordSerializer createOuterSerializerWithNestedSerializers( TypeSerializer[] nestedSerializers) { - TypeSerializer> statisticsSerializer = - (TypeSerializer>) nestedSerializers[0]; + TypeSerializer statisticsSerializer = + (TypeSerializer) nestedSerializers[0]; TypeSerializer recordSerializer = (TypeSerializer) nestedSerializers[1]; - return new DataStatisticsOrRecordSerializer<>(statisticsSerializer, recordSerializer); + return new StatisticsOrRecordSerializer(statisticsSerializer, recordSerializer); } } } diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsType.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsType.java new file mode 100644 index 000000000000..43f72e336e06 --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsType.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +/** + * Range distribution requires gathering statistics on the sort keys to determine proper range + * boundaries to distribute/cluster rows before writer operators. + */ +public enum StatisticsType { + /** + * Tracks the data statistics as {@code Map} frequency. It works better for + * low-cardinality scenarios (like country, event_type, etc.) where the cardinalities are in + * hundreds or thousands. + * + *

      + *
    • Pro: accurate measurement of the statistics/weight of every key. + *
    • Con: memory footprint can be large if the key cardinality is high. + *
    + */ + Map, + + /** + * Sample the sort keys via reservoir sampling. Then split the range partitions via range bounds + * from sampled values. It works better for high-cardinality scenarios (like device_id, user_id, + * uuid etc.) where the cardinalities can be in millions or billions. + * + *
      + *
    • Pro: relatively low memory footprint for high-cardinality sort keys. + *
    • Con: approximate statistics with potentially lower accuracy. + *
    + */ + Sketch, + + /** + * Initially use Map for statistics tracking. If key cardinality turns out to be high, + * automatically switch to sketch sampling. + */ + Auto +} diff --git a/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsUtil.java b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsUtil.java new file mode 100644 index 000000000000..5d48ec57ca49 --- /dev/null +++ b/flink/v1.18/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/StatisticsUtil.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.io.IOException; +import java.io.UncheckedIOException; +import javax.annotation.Nullable; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputDeserializer; +import org.apache.flink.core.memory.DataOutputSerializer; + +class StatisticsUtil { + + private StatisticsUtil() {} + + static DataStatistics createTaskStatistics( + StatisticsType type, int operatorParallelism, int numPartitions) { + if (type == StatisticsType.Map) { + return new MapDataStatistics(); + } else { + return new SketchDataStatistics( + SketchUtil.determineOperatorReservoirSize(operatorParallelism, numPartitions)); + } + } + + static byte[] serializeDataStatistics( + DataStatistics dataStatistics, TypeSerializer statisticsSerializer) { + DataOutputSerializer out = new DataOutputSerializer(64); + try { + statisticsSerializer.serialize(dataStatistics, out); + return out.getCopyOfBuffer(); + } catch (IOException e) { + throw new UncheckedIOException("Fail to serialize data statistics", e); + } + } + + static DataStatistics deserializeDataStatistics( + byte[] bytes, TypeSerializer statisticsSerializer) { + DataInputDeserializer input = new DataInputDeserializer(bytes, 0, bytes.length); + try { + return statisticsSerializer.deserialize(input); + } catch (IOException e) { + throw new UncheckedIOException("Fail to deserialize data statistics", e); + } + } + + static byte[] serializeCompletedStatistics( + CompletedStatistics completedStatistics, + TypeSerializer statisticsSerializer) { + try { + DataOutputSerializer out = new DataOutputSerializer(1024); + statisticsSerializer.serialize(completedStatistics, out); + return out.getCopyOfBuffer(); + } catch (IOException e) { + throw new UncheckedIOException("Fail to serialize aggregated statistics", e); + } + } + + static CompletedStatistics deserializeCompletedStatistics( + byte[] bytes, TypeSerializer statisticsSerializer) { + try { + DataInputDeserializer input = new DataInputDeserializer(bytes); + return statisticsSerializer.deserialize(input); + } catch (IOException e) { + throw new UncheckedIOException("Fail to 
deserialize aggregated statistics", e); + } + } + + static byte[] serializeGlobalStatistics( + GlobalStatistics globalStatistics, TypeSerializer statisticsSerializer) { + try { + DataOutputSerializer out = new DataOutputSerializer(1024); + statisticsSerializer.serialize(globalStatistics, out); + return out.getCopyOfBuffer(); + } catch (IOException e) { + throw new UncheckedIOException("Fail to serialize aggregated statistics", e); + } + } + + static GlobalStatistics deserializeGlobalStatistics( + byte[] bytes, TypeSerializer statisticsSerializer) { + try { + DataInputDeserializer input = new DataInputDeserializer(bytes); + return statisticsSerializer.deserialize(input); + } catch (IOException e) { + throw new UncheckedIOException("Fail to deserialize aggregated statistics", e); + } + } + + static StatisticsType collectType(StatisticsType config) { + return config == StatisticsType.Sketch ? StatisticsType.Sketch : StatisticsType.Map; + } + + static StatisticsType collectType(StatisticsType config, @Nullable GlobalStatistics statistics) { + if (statistics != null) { + return statistics.type(); + } + + return collectType(config); + } + + static StatisticsType collectType( + StatisticsType config, @Nullable CompletedStatistics statistics) { + if (statistics != null) { + return statistics.type(); + } + + return collectType(config); + } +} diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/Fixtures.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/Fixtures.java new file mode 100644 index 000000000000..5910bd685510 --- /dev/null +++ b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/Fixtures.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import java.util.Comparator; +import java.util.Map; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.typeutils.RowDataSerializer; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; +import org.apache.iceberg.Schema; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.SortOrder; +import org.apache.iceberg.SortOrderComparators; +import org.apache.iceberg.StructLike; +import org.apache.iceberg.flink.RowDataWrapper; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; + +class Fixtures { + private Fixtures() {} + + public static final int NUM_SUBTASKS = 2; + public static final Schema SCHEMA = + new Schema( + Types.NestedField.optional(1, "id", Types.StringType.get()), + Types.NestedField.optional(2, "number", Types.IntegerType.get())); + public static final RowType ROW_TYPE = RowType.of(new VarCharType(), new IntType()); + public static final TypeSerializer ROW_SERIALIZER = new RowDataSerializer(ROW_TYPE); + public static final RowDataWrapper ROW_WRAPPER = new RowDataWrapper(ROW_TYPE, SCHEMA.asStruct()); + public static final SortOrder SORT_ORDER = SortOrder.builderFor(SCHEMA).asc("id").build(); + public static final Comparator SORT_ORDER_COMPARTOR = + SortOrderComparators.forSchema(SCHEMA, SORT_ORDER); + public static final SortKeySerializer SORT_KEY_SERIALIZER = + new SortKeySerializer(SCHEMA, SORT_ORDER); + public static final DataStatisticsSerializer TASK_STATISTICS_SERIALIZER = + new DataStatisticsSerializer(SORT_KEY_SERIALIZER); + public static final GlobalStatisticsSerializer GLOBAL_STATISTICS_SERIALIZER = + new GlobalStatisticsSerializer(SORT_KEY_SERIALIZER); + public static final CompletedStatisticsSerializer COMPLETED_STATISTICS_SERIALIZER = + new CompletedStatisticsSerializer(SORT_KEY_SERIALIZER); + + public static final SortKey SORT_KEY = new SortKey(SCHEMA, SORT_ORDER); + public static final Map CHAR_KEYS = createCharKeys(); + + public static StatisticsEvent createStatisticsEvent( + StatisticsType type, + TypeSerializer statisticsSerializer, + long checkpointId, + SortKey... keys) { + DataStatistics statistics = createTaskStatistics(type, keys); + return StatisticsEvent.createTaskStatisticsEvent( + checkpointId, statistics, statisticsSerializer); + } + + public static DataStatistics createTaskStatistics(StatisticsType type, SortKey... 
keys) { + DataStatistics statistics; + if (type == StatisticsType.Sketch) { + statistics = new SketchDataStatistics(128); + } else { + statistics = new MapDataStatistics(); + } + + for (SortKey key : keys) { + statistics.add(key); + } + + return statistics; + } + + private static Map createCharKeys() { + Map keys = Maps.newHashMap(); + for (char c = 'a'; c <= 'z'; ++c) { + String key = Character.toString(c); + SortKey sortKey = SORT_KEY.copy(); + sortKey.set(0, key); + keys.put(key, sortKey); + } + + return keys; + } +} diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java deleted file mode 100644 index 739cf764e2a6..000000000000 --- a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.iceberg.flink.sink.shuffle; - -import static org.assertj.core.api.Assertions.assertThat; - -import java.util.Map; -import org.apache.iceberg.Schema; -import org.apache.iceberg.SortKey; -import org.apache.iceberg.SortOrder; -import org.apache.iceberg.types.Types; -import org.junit.jupiter.api.Test; - -public class TestAggregatedStatistics { - private final Schema schema = - new Schema(Types.NestedField.optional(1, "str", Types.StringType.get())); - private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build(); - private final SortKey sortKey = new SortKey(schema, sortOrder); - private final MapDataStatisticsSerializer statisticsSerializer = - MapDataStatisticsSerializer.fromSortKeySerializer(new SortKeySerializer(schema, sortOrder)); - - @Test - public void mergeDataStatisticTest() { - SortKey keyA = sortKey.copy(); - keyA.set(0, "a"); - SortKey keyB = sortKey.copy(); - keyB.set(0, "b"); - - AggregatedStatistics> aggregatedStatistics = - new AggregatedStatistics<>(1, statisticsSerializer); - MapDataStatistics mapDataStatistics1 = new MapDataStatistics(); - mapDataStatistics1.add(keyA); - mapDataStatistics1.add(keyA); - mapDataStatistics1.add(keyB); - aggregatedStatistics.mergeDataStatistic("testOperator", 1, mapDataStatistics1); - MapDataStatistics mapDataStatistics2 = new MapDataStatistics(); - mapDataStatistics2.add(keyA); - aggregatedStatistics.mergeDataStatistic("testOperator", 1, mapDataStatistics2); - assertThat(aggregatedStatistics.dataStatistics().statistics().get(keyA)) - .isEqualTo( - mapDataStatistics1.statistics().get(keyA) + mapDataStatistics2.statistics().get(keyA)); - assertThat(aggregatedStatistics.dataStatistics().statistics().get(keyB)) - .isEqualTo( - mapDataStatistics1.statistics().get(keyB) - + 
mapDataStatistics2.statistics().getOrDefault(keyB, 0L)); - } -} diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java index 0064c91340bf..8322ce683768 100644 --- a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java +++ b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java @@ -18,161 +18,448 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.TASK_STATISTICS_SERIALIZER; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.createStatisticsEvent; import static org.assertj.core.api.Assertions.assertThat; -import java.util.Map; -import org.apache.iceberg.Schema; -import org.apache.iceberg.SortKey; -import org.apache.iceberg.SortOrder; -import org.apache.iceberg.types.Types; -import org.junit.jupiter.api.BeforeEach; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; public class TestAggregatedStatisticsTracker { - private static final int NUM_SUBTASKS = 2; - - private final Schema schema = - new Schema(Types.NestedField.optional(1, "str", Types.StringType.get())); - private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build(); - private final SortKey sortKey = new SortKey(schema, sortOrder); - private final MapDataStatisticsSerializer statisticsSerializer = - MapDataStatisticsSerializer.fromSortKeySerializer(new SortKeySerializer(schema, sortOrder)); - private final SortKey keyA = sortKey.copy(); - private final SortKey keyB = sortKey.copy(); - - private AggregatedStatisticsTracker> - aggregatedStatisticsTracker; - - public TestAggregatedStatisticsTracker() { - keyA.set(0, "a"); - keyB.set(0, "b"); - } + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void receiveNewerStatisticsEvent(StatisticsType type) { + AggregatedStatisticsTracker tracker = createTracker(type); - @BeforeEach - public void before() throws Exception { - aggregatedStatisticsTracker = - new AggregatedStatisticsTracker<>("testOperator", statisticsSerializer, NUM_SUBTASKS); - } + StatisticsEvent checkpoint1Subtask0StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("a")); + CompletedStatistics completedStatistics = + tracker.updateAndCheckCompletion(0, checkpoint1Subtask0StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L); + AggregatedStatisticsTracker.Aggregation aggregation = + tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()).isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("a")); + } - @Test - public void receiveNewerDataStatisticEvent() { - MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics(); - 
checkpoint1Subtask0DataStatistic.add(keyA); - DataStatisticsEvent> - checkpoint1Subtask0DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer); - assertThat( - aggregatedStatisticsTracker.updateAndCheckCompletion( - 0, checkpoint1Subtask0DataStatisticEvent)) - .isNull(); - assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(1); - - MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics(); - checkpoint2Subtask0DataStatistic.add(keyA); - DataStatisticsEvent> - checkpoint2Subtask0DataStatisticEvent = - DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer); - assertThat( - aggregatedStatisticsTracker.updateAndCheckCompletion( - 0, checkpoint2Subtask0DataStatisticEvent)) - .isNull(); - // Checkpoint 2 is newer than checkpoint1, thus dropping in progress statistics for checkpoint1 - assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(2); + StatisticsEvent checkpoint2Subtask0StatisticsEvent = + createStatisticsEvent( + type, + TASK_STATISTICS_SERIALIZER, + 2L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b")); + completedStatistics = tracker.updateAndCheckCompletion(0, checkpoint2Subtask0StatisticsEvent); + assertThat(completedStatistics).isNull(); + // both checkpoints are tracked + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L, 2L); + aggregation = tracker.aggregationsPerCheckpoint().get(2L); + assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 2L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("b")); + } + + StatisticsEvent checkpoint1Subtask1StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("b")); + completedStatistics = tracker.updateAndCheckCompletion(1, checkpoint1Subtask1StatisticsEvent); + // checkpoint 1 is completed + assertThat(completedStatistics).isNotNull(); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + assertThat(completedStatistics.checkpointId()).isEqualTo(1L); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()) + .isEqualTo( + ImmutableMap.of( + CHAR_KEYS.get("a"), 1L, + CHAR_KEYS.get("b"), 1L)); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly(CHAR_KEYS.get("a"), CHAR_KEYS.get("b")); + } + + // checkpoint 2 remains + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(2L); + aggregation = tracker.aggregationsPerCheckpoint().get(2L); + assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 2L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("b")); + } } - @Test - public void receiveOlderDataStatisticEventTest() { - MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics(); - 
checkpoint2Subtask0DataStatistic.add(keyA); - checkpoint2Subtask0DataStatistic.add(keyB); - checkpoint2Subtask0DataStatistic.add(keyB); - DataStatisticsEvent> - checkpoint3Subtask0DataStatisticEvent = - DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer); - assertThat( - aggregatedStatisticsTracker.updateAndCheckCompletion( - 0, checkpoint3Subtask0DataStatisticEvent)) - .isNull(); - - MapDataStatistics checkpoint1Subtask1DataStatistic = new MapDataStatistics(); - checkpoint1Subtask1DataStatistic.add(keyB); - DataStatisticsEvent> - checkpoint1Subtask1DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, statisticsSerializer); - // Receive event from old checkpoint, aggregatedStatisticsAggregatorTracker won't return - // completed statistics and in progress statistics won't be updated - assertThat( - aggregatedStatisticsTracker.updateAndCheckCompletion( - 1, checkpoint1Subtask1DataStatisticEvent)) - .isNull(); - assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(2); + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void receiveOlderStatisticsEventTest(StatisticsType type) { + AggregatedStatisticsTracker tracker = createTracker(type); + + StatisticsEvent checkpoint2Subtask0StatisticsEvent = + createStatisticsEvent( + type, + TASK_STATISTICS_SERIALIZER, + 2L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b")); + CompletedStatistics completedStatistics = + tracker.updateAndCheckCompletion(0, checkpoint2Subtask0StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(2L); + AggregatedStatisticsTracker.Aggregation aggregation = + tracker.aggregationsPerCheckpoint().get(2L); + assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 2L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("b")); + } + + StatisticsEvent checkpoint1Subtask1StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("b")); + completedStatistics = tracker.updateAndCheckCompletion(1, checkpoint1Subtask1StatisticsEvent); + assertThat(completedStatistics).isNull(); + // both checkpoints are tracked + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L, 2L); + aggregation = tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()).isEqualTo(ImmutableMap.of(CHAR_KEYS.get("b"), 1L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("b")); + } + + StatisticsEvent checkpoint3Subtask0StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 3L, CHAR_KEYS.get("x")); + completedStatistics = tracker.updateAndCheckCompletion(1, checkpoint3Subtask0StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L, 2L, 3L); + aggregation = tracker.aggregationsPerCheckpoint().get(3L); + 
assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()).isEqualTo(ImmutableMap.of(CHAR_KEYS.get("x"), 1L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("x")); + } + + StatisticsEvent checkpoint2Subtask1StatisticsEvent = + createStatisticsEvent( + type, + TASK_STATISTICS_SERIALIZER, + 2L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b")); + completedStatistics = tracker.updateAndCheckCompletion(1, checkpoint2Subtask1StatisticsEvent); + // checkpoint 1 is cleared along with checkpoint 2. checkpoint 3 remains + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(3L); + aggregation = tracker.aggregationsPerCheckpoint().get(3L); + assertThat(aggregation.currentType()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()).isEqualTo(ImmutableMap.of(CHAR_KEYS.get("x"), 1L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("x")); + } + + assertThat(completedStatistics).isNotNull(); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + assertThat(completedStatistics.checkpointId()).isEqualTo(2L); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()) + .isEqualTo( + ImmutableMap.of( + CHAR_KEYS.get("a"), 2L, + CHAR_KEYS.get("b"), 4L)); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly( + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b")); + } } - @Test - public void receiveCompletedDataStatisticEvent() { - MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics(); - checkpoint1Subtask0DataStatistic.add(keyA); - checkpoint1Subtask0DataStatistic.add(keyB); - checkpoint1Subtask0DataStatistic.add(keyB); - DataStatisticsEvent> - checkpoint1Subtask0DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer); - assertThat( - aggregatedStatisticsTracker.updateAndCheckCompletion( - 0, checkpoint1Subtask0DataStatisticEvent)) - .isNull(); - - MapDataStatistics checkpoint1Subtask1DataStatistic = new MapDataStatistics(); - checkpoint1Subtask1DataStatistic.add(keyA); - checkpoint1Subtask1DataStatistic.add(keyA); - checkpoint1Subtask1DataStatistic.add(keyB); - DataStatisticsEvent> - checkpoint1Subtask1DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, statisticsSerializer); + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void receiveCompletedStatisticsEvent(StatisticsType type) { + AggregatedStatisticsTracker tracker = createTracker(type); + + StatisticsEvent checkpoint1Subtask0DataStatisticEvent = + createStatisticsEvent( + type, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b")); + + CompletedStatistics completedStatistics = + tracker.updateAndCheckCompletion(0, checkpoint1Subtask0DataStatisticEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L); + AggregatedStatisticsTracker.Aggregation aggregation = + 
tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.subtaskSet()).containsExactlyInAnyOrder(0); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 2L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("b")); + } + + StatisticsEvent checkpoint1Subtask1DataStatisticEvent = + createStatisticsEvent( + type, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b")); + // Receive data statistics from all subtasks at checkpoint 1 - AggregatedStatistics> completedStatistics = - aggregatedStatisticsTracker.updateAndCheckCompletion( - 1, checkpoint1Subtask1DataStatisticEvent); + completedStatistics = + tracker.updateAndCheckCompletion(1, checkpoint1Subtask1DataStatisticEvent); + assertThat(tracker.aggregationsPerCheckpoint()).isEmpty(); assertThat(completedStatistics).isNotNull(); - assertThat(completedStatistics.checkpointId()).isEqualTo(1); - MapDataStatistics globalDataStatistics = - (MapDataStatistics) completedStatistics.dataStatistics(); - assertThat((long) globalDataStatistics.statistics().get(keyA)) - .isEqualTo( - checkpoint1Subtask0DataStatistic.statistics().get(keyA) - + checkpoint1Subtask1DataStatistic.statistics().get(keyA)); - assertThat((long) globalDataStatistics.statistics().get(keyB)) - .isEqualTo( - checkpoint1Subtask0DataStatistic.statistics().get(keyB) - + checkpoint1Subtask1DataStatistic.statistics().get(keyB)); - assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()) - .isEqualTo(completedStatistics.checkpointId() + 1); - - MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics(); - checkpoint2Subtask0DataStatistic.add(keyA); - DataStatisticsEvent> - checkpoint2Subtask0DataStatisticEvent = - DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer); - assertThat( - aggregatedStatisticsTracker.updateAndCheckCompletion( - 0, checkpoint2Subtask0DataStatisticEvent)) - .isNull(); - assertThat(completedStatistics.checkpointId()).isEqualTo(1); - - MapDataStatistics checkpoint2Subtask1DataStatistic = new MapDataStatistics(); - checkpoint2Subtask1DataStatistic.add(keyB); - DataStatisticsEvent> - checkpoint2Subtask1DataStatisticEvent = - DataStatisticsEvent.create(2, checkpoint2Subtask1DataStatistic, statisticsSerializer); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + assertThat(completedStatistics.checkpointId()).isEqualTo(1L); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()) + .isEqualTo( + ImmutableMap.of( + CHAR_KEYS.get("a"), 3L, + CHAR_KEYS.get("b"), 3L)); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly( + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("a"), + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b")); + } + + StatisticsEvent checkpoint2Subtask0DataStatisticEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 2L, CHAR_KEYS.get("a")); + completedStatistics = + tracker.updateAndCheckCompletion(0, checkpoint2Subtask0DataStatisticEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(2L); + aggregation = tracker.aggregationsPerCheckpoint().get(2L); + 
assertThat(aggregation.subtaskSet()).containsExactlyInAnyOrder(0); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(aggregation.mapStatistics()).isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L)); + } else { + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder(CHAR_KEYS.get("a")); + } + + StatisticsEvent checkpoint2Subtask1DataStatisticEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 2L, CHAR_KEYS.get("b")); // Receive data statistics from all subtasks at checkpoint 2 completedStatistics = - aggregatedStatisticsTracker.updateAndCheckCompletion( - 1, checkpoint2Subtask1DataStatisticEvent); + tracker.updateAndCheckCompletion(1, checkpoint2Subtask1DataStatisticEvent); + assertThat(tracker.aggregationsPerCheckpoint()).isEmpty(); + + assertThat(completedStatistics).isNotNull(); + assertThat(completedStatistics.checkpointId()).isEqualTo(2L); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()) + .isEqualTo( + ImmutableMap.of( + CHAR_KEYS.get("a"), 1L, + CHAR_KEYS.get("b"), 1L)); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly(CHAR_KEYS.get("a"), CHAR_KEYS.get("b")); + } + } + + @Test + public void coordinatorSwitchToSketchOverThreshold() { + int parallelism = 3; + int downstreamParallelism = 3; + int switchToSketchThreshold = 3; + AggregatedStatisticsTracker tracker = + new AggregatedStatisticsTracker( + "testOperator", + parallelism, + Fixtures.SCHEMA, + Fixtures.SORT_ORDER, + downstreamParallelism, + StatisticsType.Auto, + switchToSketchThreshold, + null); + + StatisticsEvent checkpoint1Subtask0StatisticsEvent = + createStatisticsEvent( + StatisticsType.Map, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b")); + CompletedStatistics completedStatistics = + tracker.updateAndCheckCompletion(0, checkpoint1Subtask0StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L); + AggregatedStatisticsTracker.Aggregation aggregation = + tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.subtaskSet()).containsExactlyInAnyOrder(0); + assertThat(aggregation.currentType()).isEqualTo(StatisticsType.Map); + assertThat(aggregation.sketchStatistics()).isNull(); + assertThat(aggregation.mapStatistics()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 1L)); + + StatisticsEvent checkpoint1Subtask1StatisticsEvent = + createStatisticsEvent( + StatisticsType.Map, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d")); + completedStatistics = tracker.updateAndCheckCompletion(1, checkpoint1Subtask1StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L); + aggregation = tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.subtaskSet()).containsExactlyInAnyOrder(0, 1); + // converted to sketch statistics as map size is 4 (over the switch threshold of 3) + assertThat(aggregation.currentType()).isEqualTo(StatisticsType.Sketch); + assertThat(aggregation.mapStatistics()).isNull(); + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder( + CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("c"), CHAR_KEYS.get("d")); + StatisticsEvent checkpoint1Subtask2StatisticsEvent = + createStatisticsEvent( + 
StatisticsType.Map, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f")); + completedStatistics = tracker.updateAndCheckCompletion(2, checkpoint1Subtask2StatisticsEvent); + assertThat(tracker.aggregationsPerCheckpoint()).isEmpty(); assertThat(completedStatistics).isNotNull(); - assertThat(completedStatistics.checkpointId()).isEqualTo(2); - assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()) - .isEqualTo(completedStatistics.checkpointId() + 1); + assertThat(completedStatistics.type()).isEqualTo(StatisticsType.Sketch); + assertThat(completedStatistics.keySamples()) + .containsExactly( + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d"), + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f")); + } + + @Test + public void coordinatorMapOperatorSketch() { + int parallelism = 3; + int downstreamParallelism = 3; + AggregatedStatisticsTracker tracker = + new AggregatedStatisticsTracker( + "testOperator", + parallelism, + Fixtures.SCHEMA, + Fixtures.SORT_ORDER, + downstreamParallelism, + StatisticsType.Auto, + SketchUtil.COORDINATOR_SKETCH_SWITCH_THRESHOLD, + null); + + // first operator event has map statistics + StatisticsEvent checkpoint1Subtask0StatisticsEvent = + createStatisticsEvent( + StatisticsType.Map, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b")); + CompletedStatistics completedStatistics = + tracker.updateAndCheckCompletion(0, checkpoint1Subtask0StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L); + AggregatedStatisticsTracker.Aggregation aggregation = + tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.subtaskSet()).containsExactlyInAnyOrder(0); + assertThat(aggregation.currentType()).isEqualTo(StatisticsType.Map); + assertThat(aggregation.sketchStatistics()).isNull(); + assertThat(aggregation.mapStatistics()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 1L)); + + // second operator event contains sketch statistics + StatisticsEvent checkpoint1Subtask1StatisticsEvent = + createStatisticsEvent( + StatisticsType.Sketch, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d")); + completedStatistics = tracker.updateAndCheckCompletion(1, checkpoint1Subtask1StatisticsEvent); + assertThat(completedStatistics).isNull(); + assertThat(tracker.aggregationsPerCheckpoint().keySet()).containsExactlyInAnyOrder(1L); + aggregation = tracker.aggregationsPerCheckpoint().get(1L); + assertThat(aggregation.subtaskSet()).containsExactlyInAnyOrder(0, 1); + assertThat(aggregation.currentType()).isEqualTo(StatisticsType.Sketch); + assertThat(aggregation.mapStatistics()).isNull(); + assertThat(aggregation.sketchStatistics().getResult().getSamples()) + .containsExactlyInAnyOrder( + CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("c"), CHAR_KEYS.get("d")); + + // third operator event has Map statistics + StatisticsEvent checkpoint1Subtask2StatisticsEvent = + createStatisticsEvent( + StatisticsType.Map, + TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f")); + completedStatistics = tracker.updateAndCheckCompletion(2, checkpoint1Subtask2StatisticsEvent); + assertThat(tracker.aggregationsPerCheckpoint()).isEmpty(); + assertThat(completedStatistics).isNotNull(); + assertThat(completedStatistics.type()).isEqualTo(StatisticsType.Sketch); + assertThat(completedStatistics.keySamples()) + .containsExactly( + 
CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d"), + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f")); + } + + private AggregatedStatisticsTracker createTracker(StatisticsType type) { + return new AggregatedStatisticsTracker( + "testOperator", + Fixtures.NUM_SUBTASKS, + Fixtures.SCHEMA, + Fixtures.SORT_ORDER, + Fixtures.NUM_SUBTASKS, + type, + SketchUtil.COORDINATOR_SKETCH_SWITCH_THRESHOLD, + null); } } diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestCompletedStatisticsSerializer.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestCompletedStatisticsSerializer.java new file mode 100644 index 000000000000..4ee9888934a8 --- /dev/null +++ b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestCompletedStatisticsSerializer.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; + +import org.apache.flink.api.common.typeutils.SerializerTestBase; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; + +public class TestCompletedStatisticsSerializer extends SerializerTestBase { + + @Override + protected TypeSerializer createSerializer() { + return Fixtures.COMPLETED_STATISTICS_SERIALIZER; + } + + @Override + protected int getLength() { + return -1; + } + + @Override + protected Class getTypeClass() { + return CompletedStatistics.class; + } + + @Override + protected CompletedStatistics[] getTestData() { + + return new CompletedStatistics[] { + CompletedStatistics.fromKeyFrequency( + 1L, ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 2L)), + CompletedStatistics.fromKeySamples(2L, new SortKey[] {CHAR_KEYS.get("a"), CHAR_KEYS.get("b")}) + }; + } +} diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java index 849253564209..a08a8a73e80c 100644 --- a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java +++ b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java @@ -18,9 +18,13 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.NUM_SUBTASKS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.SORT_ORDER_COMPARTOR; import static 
org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import java.time.Duration; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -28,128 +32,182 @@ import org.apache.flink.runtime.operators.coordination.EventReceivingTasks; import org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext; import org.apache.flink.util.ExceptionUtils; -import org.apache.iceberg.Schema; import org.apache.iceberg.SortKey; -import org.apache.iceberg.SortOrder; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; -import org.apache.iceberg.types.Types; +import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; public class TestDataStatisticsCoordinator { private static final String OPERATOR_NAME = "TestCoordinator"; private static final OperatorID TEST_OPERATOR_ID = new OperatorID(1234L, 5678L); - private static final int NUM_SUBTASKS = 2; - - private final Schema schema = - new Schema(Types.NestedField.optional(1, "str", Types.StringType.get())); - private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build(); - private final SortKey sortKey = new SortKey(schema, sortOrder); - private final MapDataStatisticsSerializer statisticsSerializer = - MapDataStatisticsSerializer.fromSortKeySerializer(new SortKeySerializer(schema, sortOrder)); private EventReceivingTasks receivingTasks; - private DataStatisticsCoordinator> - dataStatisticsCoordinator; @BeforeEach public void before() throws Exception { receivingTasks = EventReceivingTasks.createForRunningTasks(); - dataStatisticsCoordinator = - new DataStatisticsCoordinator<>( - OPERATOR_NAME, - new MockOperatorCoordinatorContext(TEST_OPERATOR_ID, NUM_SUBTASKS), - statisticsSerializer); } - private void tasksReady() throws Exception { - dataStatisticsCoordinator.start(); - setAllTasksReady(NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks); + private void tasksReady(DataStatisticsCoordinator coordinator) { + setAllTasksReady(NUM_SUBTASKS, coordinator, receivingTasks); } - @Test - public void testThrowExceptionWhenNotStarted() { - String failureMessage = "The coordinator of TestCoordinator has not started yet."; - - assertThatThrownBy( - () -> - dataStatisticsCoordinator.handleEventFromOperator( - 0, - 0, - DataStatisticsEvent.create(0, new MapDataStatistics(), statisticsSerializer))) - .isInstanceOf(IllegalStateException.class) - .hasMessage(failureMessage); - assertThatThrownBy(() -> dataStatisticsCoordinator.executionAttemptFailed(0, 0, null)) - .isInstanceOf(IllegalStateException.class) - .hasMessage(failureMessage); - assertThatThrownBy(() -> dataStatisticsCoordinator.checkpointCoordinator(0, null)) - .isInstanceOf(IllegalStateException.class) - .hasMessage(failureMessage); + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void testThrowExceptionWhenNotStarted(StatisticsType type) throws Exception { + try (DataStatisticsCoordinator dataStatisticsCoordinator = createCoordinator(type)) { + String failureMessage = "The coordinator of TestCoordinator has not started yet."; + assertThatThrownBy( + () -> + dataStatisticsCoordinator.handleEventFromOperator( + 0, + 0, + StatisticsEvent.createTaskStatisticsEvent( + 0, new MapDataStatistics(), Fixtures.TASK_STATISTICS_SERIALIZER))) + 
.isInstanceOf(IllegalStateException.class) + .hasMessage(failureMessage); + assertThatThrownBy(() -> dataStatisticsCoordinator.executionAttemptFailed(0, 0, null)) + .isInstanceOf(IllegalStateException.class) + .hasMessage(failureMessage); + assertThatThrownBy(() -> dataStatisticsCoordinator.checkpointCoordinator(0, null)) + .isInstanceOf(IllegalStateException.class) + .hasMessage(failureMessage); + } + } + + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void testDataStatisticsEventHandling(StatisticsType type) throws Exception { + try (DataStatisticsCoordinator dataStatisticsCoordinator = createCoordinator(type)) { + dataStatisticsCoordinator.start(); + tasksReady(dataStatisticsCoordinator); + + StatisticsEvent checkpoint1Subtask0DataStatisticEvent = + Fixtures.createStatisticsEvent( + type, + Fixtures.TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c")); + StatisticsEvent checkpoint1Subtask1DataStatisticEvent = + Fixtures.createStatisticsEvent( + type, + Fixtures.TASK_STATISTICS_SERIALIZER, + 1L, + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c")); + // Handle events from operators for checkpoint 1 + dataStatisticsCoordinator.handleEventFromOperator( + 0, 0, checkpoint1Subtask0DataStatisticEvent); + dataStatisticsCoordinator.handleEventFromOperator( + 1, 0, checkpoint1Subtask1DataStatisticEvent); + + waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + + Map keyFrequency = + ImmutableMap.of( + CHAR_KEYS.get("a"), 2L, + CHAR_KEYS.get("b"), 3L, + CHAR_KEYS.get("c"), 5L); + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(NUM_SUBTASKS, keyFrequency, 0.0d, SORT_ORDER_COMPARTOR); + + CompletedStatistics completedStatistics = dataStatisticsCoordinator.completedStatistics(); + assertThat(completedStatistics.checkpointId()).isEqualTo(1L); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()).isEqualTo(keyFrequency); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly( + CHAR_KEYS.get("a"), + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c")); + } + + GlobalStatistics globalStatistics = dataStatisticsCoordinator.globalStatistics(); + assertThat(globalStatistics.checkpointId()).isEqualTo(1L); + assertThat(globalStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(globalStatistics.mapAssignment()).isEqualTo(mapAssignment); + } else { + assertThat(globalStatistics.rangeBounds()).containsExactly(CHAR_KEYS.get("b")); + } + } } @Test - public void testDataStatisticsEventHandling() throws Exception { - tasksReady(); - SortKey key = sortKey.copy(); - - MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics(); - key.set(0, "a"); - checkpoint1Subtask0DataStatistic.add(key); - key.set(0, "b"); - checkpoint1Subtask0DataStatistic.add(key); - key.set(0, "b"); - checkpoint1Subtask0DataStatistic.add(key); - key.set(0, "c"); - checkpoint1Subtask0DataStatistic.add(key); - key.set(0, "c"); - checkpoint1Subtask0DataStatistic.add(key); - key.set(0, "c"); - checkpoint1Subtask0DataStatistic.add(key); - - 
DataStatisticsEvent> - checkpoint1Subtask0DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer); - - MapDataStatistics checkpoint1Subtask1DataStatistic = new MapDataStatistics(); - key.set(0, "a"); - checkpoint1Subtask1DataStatistic.add(key); - key.set(0, "b"); - checkpoint1Subtask1DataStatistic.add(key); - key.set(0, "c"); - checkpoint1Subtask1DataStatistic.add(key); - key.set(0, "c"); - checkpoint1Subtask1DataStatistic.add(key); - - DataStatisticsEvent> - checkpoint1Subtask1DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, statisticsSerializer); - - // Handle events from operators for checkpoint 1 - dataStatisticsCoordinator.handleEventFromOperator(0, 0, checkpoint1Subtask0DataStatisticEvent); - dataStatisticsCoordinator.handleEventFromOperator(1, 0, checkpoint1Subtask1DataStatisticEvent); - - waitForCoordinatorToProcessActions(dataStatisticsCoordinator); - - // Verify global data statistics is the aggregation of all subtasks data statistics - SortKey keyA = sortKey.copy(); - keyA.set(0, "a"); - SortKey keyB = sortKey.copy(); - keyB.set(0, "b"); - SortKey keyC = sortKey.copy(); - keyC.set(0, "c"); - MapDataStatistics globalDataStatistics = - (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics(); - assertThat(globalDataStatistics.statistics()) - .containsExactlyInAnyOrderEntriesOf( - ImmutableMap.of( - keyA, 2L, - keyB, 3L, - keyC, 5L)); + public void testRequestGlobalStatisticsEventHandling() throws Exception { + try (DataStatisticsCoordinator dataStatisticsCoordinator = + createCoordinator(StatisticsType.Sketch)) { + dataStatisticsCoordinator.start(); + tasksReady(dataStatisticsCoordinator); + + // receive request before global statistics is ready + dataStatisticsCoordinator.handleEventFromOperator(0, 0, new RequestGlobalStatisticsEvent()); + assertThat(receivingTasks.getSentEventsForSubtask(0)).isEmpty(); + assertThat(receivingTasks.getSentEventsForSubtask(1)).isEmpty(); + + StatisticsEvent checkpoint1Subtask0DataStatisticEvent = + Fixtures.createStatisticsEvent( + StatisticsType.Sketch, Fixtures.TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("a")); + StatisticsEvent checkpoint1Subtask1DataStatisticEvent = + Fixtures.createStatisticsEvent( + StatisticsType.Sketch, Fixtures.TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("b")); + // Handle events from operators for checkpoint 1 + dataStatisticsCoordinator.handleEventFromOperator( + 0, 0, checkpoint1Subtask0DataStatisticEvent); + dataStatisticsCoordinator.handleEventFromOperator( + 1, 0, checkpoint1Subtask1DataStatisticEvent); + + waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + Awaitility.await("wait for statistics event") + .pollInterval(Duration.ofMillis(10)) + .atMost(Duration.ofSeconds(10)) + .until(() -> receivingTasks.getSentEventsForSubtask(0).size() == 1); + assertThat(receivingTasks.getSentEventsForSubtask(0).get(0)) + .isInstanceOf(StatisticsEvent.class); + + Awaitility.await("wait for statistics event") + .pollInterval(Duration.ofMillis(10)) + .atMost(Duration.ofSeconds(10)) + .until(() -> receivingTasks.getSentEventsForSubtask(1).size() == 1); + assertThat(receivingTasks.getSentEventsForSubtask(1).get(0)) + .isInstanceOf(StatisticsEvent.class); + + dataStatisticsCoordinator.handleEventFromOperator(1, 0, new RequestGlobalStatisticsEvent()); + + // coordinator should send a response to subtask 1 + Awaitility.await("wait for statistics event") + .pollInterval(Duration.ofMillis(10)) + 
.atMost(Duration.ofSeconds(10)) + .until(() -> receivingTasks.getSentEventsForSubtask(1).size() == 2); + assertThat(receivingTasks.getSentEventsForSubtask(1).get(0)) + .isInstanceOf(StatisticsEvent.class); + assertThat(receivingTasks.getSentEventsForSubtask(1).get(1)) + .isInstanceOf(StatisticsEvent.class); + } } static void setAllTasksReady( int subtasks, - DataStatisticsCoordinator> dataStatisticsCoordinator, + DataStatisticsCoordinator dataStatisticsCoordinator, EventReceivingTasks receivingTasks) { for (int i = 0; i < subtasks; i++) { dataStatisticsCoordinator.executionAttemptReady( @@ -157,8 +215,7 @@ static void setAllTasksReady( } } - static void waitForCoordinatorToProcessActions( - DataStatisticsCoordinator> coordinator) { + static void waitForCoordinatorToProcessActions(DataStatisticsCoordinator coordinator) { CompletableFuture future = new CompletableFuture<>(); coordinator.callInCoordinatorThread( () -> { @@ -175,4 +232,15 @@ static void waitForCoordinatorToProcessActions( ExceptionUtils.rethrow(ExceptionUtils.stripExecutionException(e)); } } + + private static DataStatisticsCoordinator createCoordinator(StatisticsType type) { + return new DataStatisticsCoordinator( + OPERATOR_NAME, + new MockOperatorCoordinatorContext(TEST_OPERATOR_ID, NUM_SUBTASKS), + Fixtures.SCHEMA, + Fixtures.SORT_ORDER, + NUM_SUBTASKS, + type, + 0.0d); + } } diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java index c5216eeb712a..6317f2bfde18 100644 --- a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java +++ b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java @@ -18,6 +18,10 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.SORT_ORDER_COMPARTOR; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.TASK_STATISTICS_SERIALIZER; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.createStatisticsEvent; import static org.assertj.core.api.Assertions.assertThat; import java.util.Map; @@ -27,117 +31,157 @@ import org.apache.flink.runtime.operators.coordination.EventReceivingTasks; import org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext; import org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator; -import org.apache.iceberg.Schema; import org.apache.iceberg.SortKey; -import org.apache.iceberg.SortOrder; -import org.apache.iceberg.types.Types; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; public class TestDataStatisticsCoordinatorProvider { private static final OperatorID OPERATOR_ID = new OperatorID(); - private static final int NUM_SUBTASKS = 1; - private final Schema schema = - new Schema(Types.NestedField.optional(1, "str", Types.StringType.get())); - private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build(); - private final SortKey sortKey = new SortKey(schema, sortOrder); - private final MapDataStatisticsSerializer statisticsSerializer = - 
MapDataStatisticsSerializer.fromSortKeySerializer(new SortKeySerializer(schema, sortOrder)); - - private DataStatisticsCoordinatorProvider> provider; private EventReceivingTasks receivingTasks; @BeforeEach public void before() { - provider = - new DataStatisticsCoordinatorProvider<>( - "DataStatisticsCoordinatorProvider", OPERATOR_ID, statisticsSerializer); receivingTasks = EventReceivingTasks.createForRunningTasks(); } - @Test - @SuppressWarnings("unchecked") - public void testCheckpointAndReset() throws Exception { - SortKey keyA = sortKey.copy(); - keyA.set(0, "a"); - SortKey keyB = sortKey.copy(); - keyB.set(0, "b"); - SortKey keyC = sortKey.copy(); - keyC.set(0, "c"); - SortKey keyD = sortKey.copy(); - keyD.set(0, "c"); - SortKey keyE = sortKey.copy(); - keyE.set(0, "c"); - + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void testCheckpointAndReset(StatisticsType type) throws Exception { + DataStatisticsCoordinatorProvider provider = createProvider(type, Fixtures.NUM_SUBTASKS); try (RecreateOnResetOperatorCoordinator coordinator = (RecreateOnResetOperatorCoordinator) - provider.create(new MockOperatorCoordinatorContext(OPERATOR_ID, NUM_SUBTASKS))) { - DataStatisticsCoordinator> dataStatisticsCoordinator = - (DataStatisticsCoordinator>) - coordinator.getInternalCoordinator(); + provider.create( + new MockOperatorCoordinatorContext(OPERATOR_ID, Fixtures.NUM_SUBTASKS))) { + DataStatisticsCoordinator dataStatisticsCoordinator = + (DataStatisticsCoordinator) coordinator.getInternalCoordinator(); // Start the coordinator coordinator.start(); TestDataStatisticsCoordinator.setAllTasksReady( - NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks); - MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics(); - checkpoint1Subtask0DataStatistic.add(keyA); - checkpoint1Subtask0DataStatistic.add(keyB); - checkpoint1Subtask0DataStatistic.add(keyC); - DataStatisticsEvent> - checkpoint1Subtask0DataStatisticEvent = - DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer); + Fixtures.NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks); // Handle events from operators for checkpoint 1 - coordinator.handleEventFromOperator(0, 0, checkpoint1Subtask0DataStatisticEvent); + StatisticsEvent checkpoint1Subtask0StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("a")); + coordinator.handleEventFromOperator(0, 0, checkpoint1Subtask0StatisticsEvent); TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + + StatisticsEvent checkpoint1Subtask1StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 1L, CHAR_KEYS.get("b")); + coordinator.handleEventFromOperator(1, 0, checkpoint1Subtask1StatisticsEvent); + TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + // Verify checkpoint 1 global data statistics - MapDataStatistics checkpoint1GlobalDataStatistics = - (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics(); - assertThat(checkpoint1GlobalDataStatistics.statistics()) - .isEqualTo(checkpoint1Subtask0DataStatistic.statistics()); + Map checkpoint1KeyFrequency = + ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 1L); + MapAssignment checkpoint1MapAssignment = + MapAssignment.fromKeyFrequency( + Fixtures.NUM_SUBTASKS, checkpoint1KeyFrequency, 0.0d, SORT_ORDER_COMPARTOR); + + CompletedStatistics completedStatistics = dataStatisticsCoordinator.completedStatistics(); 
+ assertThat(completedStatistics).isNotNull(); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()).isEqualTo(checkpoint1KeyFrequency); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly(CHAR_KEYS.get("a"), CHAR_KEYS.get("b")); + } + + GlobalStatistics globalStatistics = dataStatisticsCoordinator.globalStatistics(); + assertThat(globalStatistics).isNotNull(); + assertThat(globalStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(globalStatistics.mapAssignment()).isEqualTo(checkpoint1MapAssignment); + } else { + assertThat(globalStatistics.rangeBounds()).containsExactly(CHAR_KEYS.get("a")); + } + byte[] checkpoint1Bytes = waitForCheckpoint(1L, dataStatisticsCoordinator); - MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics(); - checkpoint2Subtask0DataStatistic.add(keyD); - checkpoint2Subtask0DataStatistic.add(keyE); - checkpoint2Subtask0DataStatistic.add(keyE); - DataStatisticsEvent> - checkpoint2Subtask0DataStatisticEvent = - DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer); - // Handle events from operators for checkpoint 2 - coordinator.handleEventFromOperator(0, 0, checkpoint2Subtask0DataStatisticEvent); + StatisticsEvent checkpoint2Subtask0StatisticsEvent = + createStatisticsEvent( + type, TASK_STATISTICS_SERIALIZER, 2L, CHAR_KEYS.get("d"), CHAR_KEYS.get("e")); + coordinator.handleEventFromOperator(0, 0, checkpoint2Subtask0StatisticsEvent); TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + + StatisticsEvent checkpoint2Subtask1StatisticsEvent = + createStatisticsEvent(type, TASK_STATISTICS_SERIALIZER, 2L, CHAR_KEYS.get("f")); + coordinator.handleEventFromOperator(1, 0, checkpoint2Subtask1StatisticsEvent); + TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator); + // Verify checkpoint 2 global data statistics - MapDataStatistics checkpoint2GlobalDataStatistics = - (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics(); - assertThat(checkpoint2GlobalDataStatistics.statistics()) - .isEqualTo(checkpoint2Subtask0DataStatistic.statistics()); + Map checkpoint2KeyFrequency = + ImmutableMap.of(CHAR_KEYS.get("d"), 1L, CHAR_KEYS.get("e"), 1L, CHAR_KEYS.get("f"), 1L); + MapAssignment checkpoint2MapAssignment = + MapAssignment.fromKeyFrequency( + Fixtures.NUM_SUBTASKS, checkpoint2KeyFrequency, 0.0d, SORT_ORDER_COMPARTOR); + completedStatistics = dataStatisticsCoordinator.completedStatistics(); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()).isEqualTo(checkpoint2KeyFrequency); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly(CHAR_KEYS.get("d"), CHAR_KEYS.get("e"), CHAR_KEYS.get("f")); + } + + globalStatistics = dataStatisticsCoordinator.globalStatistics(); + assertThat(globalStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(globalStatistics.mapAssignment()).isEqualTo(checkpoint2MapAssignment); + } else { + assertThat(globalStatistics.rangeBounds()).containsExactly(CHAR_KEYS.get("e")); + } + 
waitForCheckpoint(2L, dataStatisticsCoordinator); // Reset coordinator to checkpoint 1 coordinator.resetToCheckpoint(1L, checkpoint1Bytes); - DataStatisticsCoordinator> - restoredDataStatisticsCoordinator = - (DataStatisticsCoordinator>) - coordinator.getInternalCoordinator(); - assertThat(dataStatisticsCoordinator).isNotEqualTo(restoredDataStatisticsCoordinator); + DataStatisticsCoordinator restoredDataStatisticsCoordinator = + (DataStatisticsCoordinator) coordinator.getInternalCoordinator(); + assertThat(dataStatisticsCoordinator).isNotSameAs(restoredDataStatisticsCoordinator); + + completedStatistics = restoredDataStatisticsCoordinator.completedStatistics(); + assertThat(completedStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); // Verify restored data statistics - MapDataStatistics restoredAggregateDataStatistics = - (MapDataStatistics) - restoredDataStatisticsCoordinator.completedStatistics().dataStatistics(); - assertThat(restoredAggregateDataStatistics.statistics()) - .isEqualTo(checkpoint1GlobalDataStatistics.statistics()); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(completedStatistics.keyFrequency()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 1L)); + } else { + assertThat(completedStatistics.keySamples()) + .containsExactly(CHAR_KEYS.get("a"), CHAR_KEYS.get("b")); + } + + globalStatistics = restoredDataStatisticsCoordinator.globalStatistics(); + assertThat(globalStatistics).isNotNull(); + assertThat(globalStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(globalStatistics.mapAssignment()).isEqualTo(checkpoint1MapAssignment); + } else { + assertThat(globalStatistics.rangeBounds()).containsExactly(CHAR_KEYS.get("a")); + } } } - private byte[] waitForCheckpoint( - long checkpointId, - DataStatisticsCoordinator> coordinator) + private byte[] waitForCheckpoint(long checkpointId, DataStatisticsCoordinator coordinator) throws InterruptedException, ExecutionException { CompletableFuture future = new CompletableFuture<>(); coordinator.checkpointCoordinator(checkpointId, future); return future.get(); } + + private static DataStatisticsCoordinatorProvider createProvider( + StatisticsType type, int downstreamParallelism) { + return new DataStatisticsCoordinatorProvider( + "DataStatisticsCoordinatorProvider", + OPERATOR_ID, + Fixtures.SCHEMA, + Fixtures.SORT_ORDER, + downstreamParallelism, + type, + 0.0); + } } diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java index 5e6f971807ba..c760f1ba96d3 100644 --- a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java +++ b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java @@ -18,22 +18,25 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.SORT_ORDER_COMPARTOR; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import java.util.stream.Stream; +import 
org.apache.datasketches.sampling.ReservoirItemsSketch; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.state.OperatorStateStore; -import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.operators.coordination.MockOperatorEventGateway; -import org.apache.flink.runtime.operators.testutils.MockEnvironment; -import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.StateInitializationContext; @@ -49,102 +52,95 @@ import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.RowData; import org.apache.flink.table.data.StringData; -import org.apache.flink.table.runtime.typeutils.RowDataSerializer; -import org.apache.flink.table.types.logical.IntType; -import org.apache.flink.table.types.logical.RowType; -import org.apache.flink.table.types.logical.VarCharType; -import org.apache.iceberg.Schema; import org.apache.iceberg.SortKey; -import org.apache.iceberg.SortOrder; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; -import org.apache.iceberg.relocated.com.google.common.collect.Maps; -import org.apache.iceberg.types.Types; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; public class TestDataStatisticsOperator { - private final Schema schema = - new Schema( - Types.NestedField.optional(1, "id", Types.StringType.get()), - Types.NestedField.optional(2, "number", Types.IntegerType.get())); - private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("id").build(); - private final SortKey sortKey = new SortKey(schema, sortOrder); - private final RowType rowType = RowType.of(new VarCharType(), new IntType()); - private final TypeSerializer rowSerializer = new RowDataSerializer(rowType); - private final TypeSerializer>> - statisticsSerializer = - MapDataStatisticsSerializer.fromSortKeySerializer( - new SortKeySerializer(schema, sortOrder)); - - private DataStatisticsOperator> operator; - - private Environment getTestingEnvironment() { - return new StreamMockEnvironment( - new Configuration(), - new Configuration(), - new ExecutionConfig(), - 1L, - new MockInputSplitProvider(), - 1, - new TestTaskStateManager()); - } + + private Environment env; @BeforeEach public void before() throws Exception { - this.operator = createOperator(); - Environment env = getTestingEnvironment(); - this.operator.setup( - new OneInputStreamTask(env), - new MockStreamConfig(new Configuration(), 1), - new MockOutput<>(Lists.newArrayList())); + this.env = + new StreamMockEnvironment( + new Configuration(), + new Configuration(), + new ExecutionConfig(), + 1L, + new MockInputSplitProvider(), + 1, + new TestTaskStateManager()); } - private 
DataStatisticsOperator> createOperator() { + private DataStatisticsOperator createOperator(StatisticsType type, int downstreamParallelism) + throws Exception { MockOperatorEventGateway mockGateway = new MockOperatorEventGateway(); - return new DataStatisticsOperator<>( - "testOperator", schema, sortOrder, mockGateway, statisticsSerializer); + return createOperator(type, downstreamParallelism, mockGateway); } - @AfterEach - public void clean() throws Exception { - operator.close(); + private DataStatisticsOperator createOperator( + StatisticsType type, int downstreamParallelism, MockOperatorEventGateway mockGateway) + throws Exception { + DataStatisticsOperator operator = + new DataStatisticsOperator( + "testOperator", + Fixtures.SCHEMA, + Fixtures.SORT_ORDER, + mockGateway, + downstreamParallelism, + type); + operator.setup( + new OneInputStreamTask(env), + new MockStreamConfig(new Configuration(), 1), + new MockOutput<>(Lists.newArrayList())); + return operator; } - @Test - public void testProcessElement() throws Exception { - try (OneInputStreamOperatorTestHarness< - RowData, DataStatisticsOrRecord>> - testHarness = createHarness(this.operator)) { + @SuppressWarnings("unchecked") + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void testProcessElement(StatisticsType type) throws Exception { + DataStatisticsOperator operator = createOperator(type, Fixtures.NUM_SUBTASKS); + try (OneInputStreamOperatorTestHarness testHarness = + createHarness(operator)) { StateInitializationContext stateContext = getStateContext(); operator.initializeState(stateContext); operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("a"), 5))); operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("a"), 3))); operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("b"), 1))); - assertThat(operator.localDataStatistics()).isInstanceOf(MapDataStatistics.class); - SortKey keyA = sortKey.copy(); - keyA.set(0, "a"); - SortKey keyB = sortKey.copy(); - keyB.set(0, "b"); - Map expectedMap = ImmutableMap.of(keyA, 2L, keyB, 1L); - - MapDataStatistics mapDataStatistics = (MapDataStatistics) operator.localDataStatistics(); - Map statsMap = mapDataStatistics.statistics(); - assertThat(statsMap).hasSize(2); - assertThat(statsMap).containsExactlyInAnyOrderEntriesOf(expectedMap); + DataStatistics localStatistics = operator.localStatistics(); + assertThat(localStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + Map keyFrequency = (Map) localStatistics.result(); + assertThat(keyFrequency) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 2L, CHAR_KEYS.get("b"), 1L)); + } else { + ReservoirItemsSketch sketch = + (ReservoirItemsSketch) localStatistics.result(); + assertThat(sketch.getSamples()) + .containsExactly(CHAR_KEYS.get("a"), CHAR_KEYS.get("a"), CHAR_KEYS.get("b")); + } testHarness.endInput(); } } - @Test - public void testOperatorOutput() throws Exception { - try (OneInputStreamOperatorTestHarness< - RowData, DataStatisticsOrRecord>> - testHarness = createHarness(this.operator)) { + @ParameterizedTest + @EnumSource(StatisticsType.class) + public void testOperatorOutput(StatisticsType type) throws Exception { + DataStatisticsOperator operator = createOperator(type, Fixtures.NUM_SUBTASKS); + try (OneInputStreamOperatorTestHarness testHarness = + createHarness(operator)) { testHarness.processElement( new 
StreamRecord<>(GenericRowData.of(StringData.fromString("a"), 2))); testHarness.processElement( @@ -154,8 +150,8 @@ public void testOperatorOutput() throws Exception { List recordsOutput = testHarness.extractOutputValues().stream() - .filter(DataStatisticsOrRecord::hasRecord) - .map(DataStatisticsOrRecord::record) + .filter(StatisticsOrRecord::hasRecord) + .map(StatisticsOrRecord::record) .collect(Collectors.toList()); assertThat(recordsOutput) .containsExactlyInAnyOrderElementsOf( @@ -166,70 +162,172 @@ public void testOperatorOutput() throws Exception { } } - @Test - public void testRestoreState() throws Exception { + private static Stream provideRestoreStateParameters() { + return Stream.of( + Arguments.of(StatisticsType.Map, -1), + Arguments.of(StatisticsType.Map, 0), + Arguments.of(StatisticsType.Map, 1), + Arguments.of(StatisticsType.Sketch, -1), + Arguments.of(StatisticsType.Sketch, 0), + Arguments.of(StatisticsType.Sketch, 1)); + } + + @ParameterizedTest + @MethodSource("provideRestoreStateParameters") + public void testRestoreState(StatisticsType type, int parallelismAdjustment) throws Exception { + Map keyFrequency = + ImmutableMap.of(CHAR_KEYS.get("a"), 2L, CHAR_KEYS.get("b"), 1L, CHAR_KEYS.get("c"), 1L); + SortKey[] rangeBounds = new SortKey[] {CHAR_KEYS.get("a")}; + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(2, keyFrequency, 0.0d, SORT_ORDER_COMPARTOR); + DataStatisticsOperator operator = createOperator(type, Fixtures.NUM_SUBTASKS); OperatorSubtaskState snapshot; - try (OneInputStreamOperatorTestHarness< - RowData, DataStatisticsOrRecord>> - testHarness1 = createHarness(this.operator)) { - MapDataStatistics mapDataStatistics = new MapDataStatistics(); - - SortKey key = sortKey.copy(); - key.set(0, "a"); - mapDataStatistics.add(key); - key.set(0, "a"); - mapDataStatistics.add(key); - key.set(0, "b"); - mapDataStatistics.add(key); - key.set(0, "c"); - mapDataStatistics.add(key); - - SortKey keyA = sortKey.copy(); - keyA.set(0, "a"); - SortKey keyB = sortKey.copy(); - keyB.set(0, "b"); - SortKey keyC = sortKey.copy(); - keyC.set(0, "c"); - Map expectedMap = ImmutableMap.of(keyA, 2L, keyB, 1L, keyC, 1L); - - DataStatisticsEvent> event = - DataStatisticsEvent.create(0, mapDataStatistics, statisticsSerializer); + try (OneInputStreamOperatorTestHarness testHarness1 = + createHarness(operator)) { + GlobalStatistics statistics; + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + statistics = GlobalStatistics.fromMapAssignment(1L, mapAssignment); + } else { + statistics = GlobalStatistics.fromRangeBounds(1L, rangeBounds); + } + + StatisticsEvent event = + StatisticsEvent.createGlobalStatisticsEvent( + statistics, Fixtures.GLOBAL_STATISTICS_SERIALIZER, false); operator.handleOperatorEvent(event); - assertThat(operator.globalDataStatistics()).isInstanceOf(MapDataStatistics.class); - assertThat(operator.globalDataStatistics().statistics()) - .containsExactlyInAnyOrderEntriesOf(expectedMap); + + GlobalStatistics globalStatistics = operator.globalStatistics(); + assertThat(globalStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(globalStatistics.mapAssignment()).isEqualTo(mapAssignment); + assertThat(globalStatistics.rangeBounds()).isNull(); + } else { + assertThat(globalStatistics.mapAssignment()).isNull(); + assertThat(globalStatistics.rangeBounds()).isEqualTo(rangeBounds); + } + snapshot = testHarness1.snapshot(1L, 0); } // Use the snapshot to initialize state for 
another new operator and then verify that the global // statistics for the new operator is same as before - DataStatisticsOperator> restoredOperator = - createOperator(); - try (OneInputStreamOperatorTestHarness< - RowData, DataStatisticsOrRecord>> - testHarness2 = new OneInputStreamOperatorTestHarness<>(restoredOperator, 2, 2, 1)) { + MockOperatorEventGateway spyGateway = Mockito.spy(new MockOperatorEventGateway()); + DataStatisticsOperator restoredOperator = + createOperator(type, Fixtures.NUM_SUBTASKS + parallelismAdjustment, spyGateway); + try (OneInputStreamOperatorTestHarness testHarness2 = + new OneInputStreamOperatorTestHarness<>(restoredOperator, 2, 2, 1)) { testHarness2.setup(); testHarness2.initializeState(snapshot); - assertThat(restoredOperator.globalDataStatistics()).isInstanceOf(MapDataStatistics.class); - // restored RowData is BinaryRowData. convert to GenericRowData for comparison - Map restoredStatistics = Maps.newHashMap(); - restoredStatistics.putAll(restoredOperator.globalDataStatistics().statistics()); + GlobalStatistics globalStatistics = restoredOperator.globalStatistics(); + // global statistics is always restored and used initially even if + // downstream parallelism changed. + assertThat(globalStatistics).isNotNull(); + // request is always sent to coordinator during initialization. + // coordinator would respond with a new global statistics that + // has range bound recomputed with new parallelism. + verify(spyGateway).sendEventToCoordinator(any(RequestGlobalStatisticsEvent.class)); + assertThat(globalStatistics.type()).isEqualTo(StatisticsUtil.collectType(type)); + if (StatisticsUtil.collectType(type) == StatisticsType.Map) { + assertThat(globalStatistics.mapAssignment()).isEqualTo(mapAssignment); + assertThat(globalStatistics.rangeBounds()).isNull(); + } else { + assertThat(globalStatistics.mapAssignment()).isNull(); + assertThat(globalStatistics.rangeBounds()).isEqualTo(rangeBounds); + } + } + } - SortKey keyA = sortKey.copy(); - keyA.set(0, "a"); - SortKey keyB = sortKey.copy(); - keyB.set(0, "b"); - SortKey keyC = sortKey.copy(); - keyC.set(0, "c"); - Map expectedMap = ImmutableMap.of(keyA, 2L, keyB, 1L, keyC, 1L); + @SuppressWarnings("unchecked") + @Test + public void testMigrationWithLocalStatsOverThreshold() throws Exception { + DataStatisticsOperator operator = createOperator(StatisticsType.Auto, Fixtures.NUM_SUBTASKS); + try (OneInputStreamOperatorTestHarness testHarness = + createHarness(operator)) { + StateInitializationContext stateContext = getStateContext(); + operator.initializeState(stateContext); + + // add rows with unique keys + for (int i = 0; i < SketchUtil.OPERATOR_SKETCH_SWITCH_THRESHOLD; ++i) { + operator.processElement( + new StreamRecord<>(GenericRowData.of(StringData.fromString(String.valueOf(i)), i))); + assertThat(operator.localStatistics().type()).isEqualTo(StatisticsType.Map); + assertThat((Map) operator.localStatistics().result()).hasSize(i + 1); + } + + // one more item should trigger the migration to sketch stats + operator.processElement( + new StreamRecord<>(GenericRowData.of(StringData.fromString("key-trigger-migration"), 1))); + + int reservoirSize = + SketchUtil.determineOperatorReservoirSize(Fixtures.NUM_SUBTASKS, Fixtures.NUM_SUBTASKS); + + assertThat(operator.localStatistics().type()).isEqualTo(StatisticsType.Sketch); + ReservoirItemsSketch sketch = + (ReservoirItemsSketch) operator.localStatistics().result(); + assertThat(sketch.getK()).isEqualTo(reservoirSize); + 
assertThat(sketch.getN()).isEqualTo(SketchUtil.OPERATOR_SKETCH_SWITCH_THRESHOLD + 1); + // reservoir not full yet + assertThat(sketch.getN()).isLessThan(reservoirSize); + assertThat(sketch.getSamples()).hasSize((int) sketch.getN()); + + // add more items to saturate the reservoir + for (int i = 0; i < reservoirSize; ++i) { + operator.processElement( + new StreamRecord<>(GenericRowData.of(StringData.fromString(String.valueOf(i)), i))); + } + + assertThat(operator.localStatistics().type()).isEqualTo(StatisticsType.Sketch); + sketch = (ReservoirItemsSketch) operator.localStatistics().result(); + assertThat(sketch.getK()).isEqualTo(reservoirSize); + assertThat(sketch.getN()) + .isEqualTo(SketchUtil.OPERATOR_SKETCH_SWITCH_THRESHOLD + 1 + reservoirSize); + // reservoir is full now + assertThat(sketch.getN()).isGreaterThan(reservoirSize); + assertThat(sketch.getSamples()).hasSize(reservoirSize); + + testHarness.endInput(); + } + } + + @SuppressWarnings("unchecked") + @Test + public void testMigrationWithGlobalSketchStatistics() throws Exception { + DataStatisticsOperator operator = createOperator(StatisticsType.Auto, Fixtures.NUM_SUBTASKS); + try (OneInputStreamOperatorTestHarness testHarness = + createHarness(operator)) { + StateInitializationContext stateContext = getStateContext(); + operator.initializeState(stateContext); - // started with Map type + operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("a"), 1))); + assertThat(operator.localStatistics().type()).isEqualTo(StatisticsType.Map); + assertThat((Map) operator.localStatistics().result()) + .isEqualTo(ImmutableMap.of(CHAR_KEYS.get("a"), 1L)); + + // received global statistics with sketch type + GlobalStatistics globalStatistics = + GlobalStatistics.fromRangeBounds( + 1L, new SortKey[] {CHAR_KEYS.get("c"), CHAR_KEYS.get("f")}); + operator.handleOperatorEvent( + StatisticsEvent.createGlobalStatisticsEvent( + globalStatistics, Fixtures.GLOBAL_STATISTICS_SERIALIZER, false)); + + int reservoirSize = + SketchUtil.determineOperatorReservoirSize(Fixtures.NUM_SUBTASKS, Fixtures.NUM_SUBTASKS); + + assertThat(operator.localStatistics().type()).isEqualTo(StatisticsType.Sketch); + ReservoirItemsSketch sketch = + (ReservoirItemsSketch) operator.localStatistics().result(); + assertThat(sketch.getK()).isEqualTo(reservoirSize); + assertThat(sketch.getN()).isEqualTo(1); + assertThat(sketch.getSamples()).isEqualTo(new SortKey[] {CHAR_KEYS.get("a")}); + + testHarness.endInput(); + } } private StateInitializationContext getStateContext() throws Exception { - MockEnvironment env = new MockEnvironmentBuilder().build(); AbstractStateBackend abstractStateBackend = new HashMapStateBackend(); CloseableRegistry cancelStreamRegistry = new CloseableRegistry(); OperatorStateStore operatorStateStore = @@ -238,17 +336,14 @@ private StateInitializationContext getStateContext() throws Exception { return new StateInitializationContextImpl(null, operatorStateStore, null, null, null); } - private OneInputStreamOperatorTestHarness< - RowData, DataStatisticsOrRecord>> - createHarness( - final DataStatisticsOperator> - dataStatisticsOperator) - throws Exception { - - OneInputStreamOperatorTestHarness< - RowData, DataStatisticsOrRecord>> - harness = new OneInputStreamOperatorTestHarness<>(dataStatisticsOperator, 1, 1, 0); - harness.setup(new DataStatisticsOrRecordSerializer<>(statisticsSerializer, rowSerializer)); + private OneInputStreamOperatorTestHarness
createHarness( + DataStatisticsOperator dataStatisticsOperator) throws Exception { + OneInputStreamOperatorTestHarness harness = + new OneInputStreamOperatorTestHarness<>( + dataStatisticsOperator, Fixtures.NUM_SUBTASKS, Fixtures.NUM_SUBTASKS, 0); + harness.setup( + new StatisticsOrRecordSerializer( + Fixtures.GLOBAL_STATISTICS_SERIALIZER, Fixtures.ROW_SERIALIZER)); harness.open(); return harness; } diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsSerializer.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsSerializer.java new file mode 100644 index 000000000000..59ce6df05d9d --- /dev/null +++ b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsSerializer.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; + +import org.apache.flink.api.common.typeutils.SerializerTestBase; +import org.apache.flink.api.common.typeutils.TypeSerializer; + +public class TestDataStatisticsSerializer extends SerializerTestBase { + @Override + protected TypeSerializer createSerializer() { + return Fixtures.TASK_STATISTICS_SERIALIZER; + } + + @Override + protected int getLength() { + return -1; + } + + @Override + protected Class getTypeClass() { + return DataStatistics.class; + } + + @Override + protected DataStatistics[] getTestData() { + return new DataStatistics[] { + new MapDataStatistics(), + Fixtures.createTaskStatistics( + StatisticsType.Map, CHAR_KEYS.get("a"), CHAR_KEYS.get("a"), CHAR_KEYS.get("b")), + new SketchDataStatistics(128), + Fixtures.createTaskStatistics( + StatisticsType.Sketch, CHAR_KEYS.get("a"), CHAR_KEYS.get("a"), CHAR_KEYS.get("b")) + }; + } +} diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestGlobalStatisticsSerializer.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestGlobalStatisticsSerializer.java new file mode 100644 index 000000000000..7afaf239c668 --- /dev/null +++ b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestGlobalStatisticsSerializer.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.SORT_ORDER_COMPARTOR; + +import org.apache.flink.api.common.typeutils.SerializerTestBase; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.iceberg.SortKey; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; + +public class TestGlobalStatisticsSerializer extends SerializerTestBase { + + @Override + protected TypeSerializer createSerializer() { + return Fixtures.GLOBAL_STATISTICS_SERIALIZER; + } + + @Override + protected int getLength() { + return -1; + } + + @Override + protected Class getTypeClass() { + return GlobalStatistics.class; + } + + @Override + protected GlobalStatistics[] getTestData() { + return new GlobalStatistics[] { + GlobalStatistics.fromMapAssignment( + 1L, + MapAssignment.fromKeyFrequency( + Fixtures.NUM_SUBTASKS, + ImmutableMap.of(CHAR_KEYS.get("a"), 1L, CHAR_KEYS.get("b"), 2L), + 0.0d, + SORT_ORDER_COMPARTOR)), + GlobalStatistics.fromRangeBounds(2L, new SortKey[] {CHAR_KEYS.get("a"), CHAR_KEYS.get("b")}) + }; + } +} diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapDataStatistics.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapDataStatistics.java index be2beeebc93c..8a25c7ad9898 100644 --- a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapDataStatistics.java +++ b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapDataStatistics.java @@ -18,74 +18,50 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.ROW_WRAPPER; import static org.assertj.core.api.Assertions.assertThat; import java.util.Map; import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.StringData; -import org.apache.flink.table.types.logical.RowType; import org.apache.iceberg.SortKey; -import org.apache.iceberg.SortOrder; -import org.apache.iceberg.flink.FlinkSchemaUtil; -import org.apache.iceberg.flink.RowDataWrapper; -import org.apache.iceberg.flink.TestFixtures; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.Test; public class TestMapDataStatistics { - private final SortOrder sortOrder = SortOrder.builderFor(TestFixtures.SCHEMA).asc("data").build(); - private final SortKey sortKey = new SortKey(TestFixtures.SCHEMA, sortOrder); - private final RowType rowType = FlinkSchemaUtil.convert(TestFixtures.SCHEMA); - private final RowDataWrapper rowWrapper = - new RowDataWrapper(rowType, TestFixtures.SCHEMA.asStruct()); - + @SuppressWarnings("unchecked") @Test public void testAddsAndGet() { MapDataStatistics dataStatistics = new MapDataStatistics(); - GenericRowData reusedRow = - GenericRowData.of(StringData.fromString("a"), 1, StringData.fromString("2023-06-20")); - 
sortKey.wrap(rowWrapper.wrap(reusedRow)); - dataStatistics.add(sortKey); + GenericRowData reusedRow = GenericRowData.of(StringData.fromString("a"), 1); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); reusedRow.setField(0, StringData.fromString("b")); - sortKey.wrap(rowWrapper.wrap(reusedRow)); - dataStatistics.add(sortKey); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); reusedRow.setField(0, StringData.fromString("c")); - sortKey.wrap(rowWrapper.wrap(reusedRow)); - dataStatistics.add(sortKey); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); reusedRow.setField(0, StringData.fromString("b")); - sortKey.wrap(rowWrapper.wrap(reusedRow)); - dataStatistics.add(sortKey); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); reusedRow.setField(0, StringData.fromString("a")); - sortKey.wrap(rowWrapper.wrap(reusedRow)); - dataStatistics.add(sortKey); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); reusedRow.setField(0, StringData.fromString("b")); - sortKey.wrap(rowWrapper.wrap(reusedRow)); - dataStatistics.add(sortKey); - - Map actual = dataStatistics.statistics(); - - rowWrapper.wrap( - GenericRowData.of(StringData.fromString("a"), 1, StringData.fromString("2023-06-20"))); - sortKey.wrap(rowWrapper); - SortKey keyA = sortKey.copy(); - - rowWrapper.wrap( - GenericRowData.of(StringData.fromString("b"), 1, StringData.fromString("2023-06-20"))); - sortKey.wrap(rowWrapper); - SortKey keyB = sortKey.copy(); - - rowWrapper.wrap( - GenericRowData.of(StringData.fromString("c"), 1, StringData.fromString("2023-06-20"))); - sortKey.wrap(rowWrapper); - SortKey keyC = sortKey.copy(); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); - Map expected = ImmutableMap.of(keyA, 2L, keyB, 3L, keyC, 1L); + Map actual = (Map) dataStatistics.result(); + Map expected = + ImmutableMap.of(CHAR_KEYS.get("a"), 2L, CHAR_KEYS.get("b"), 3L, CHAR_KEYS.get("c"), 1L); assertThat(actual).isEqualTo(expected); } } diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java index e6726e7db785..d5a0bebc74e7 100644 --- a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java +++ b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapRangePartitioner.java @@ -18,6 +18,7 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.SORT_ORDER_COMPARTOR; import static org.assertj.core.api.Assertions.assertThat; import java.util.List; @@ -64,65 +65,60 @@ private static SortKey[] initSortKeys() { } // Total weight is 800 - private final MapDataStatistics mapDataStatistics = - new MapDataStatistics( - ImmutableMap.of( - SORT_KEYS[0], - 350L, - SORT_KEYS[1], - 230L, - SORT_KEYS[2], - 120L, - SORT_KEYS[3], - 40L, - SORT_KEYS[4], - 10L, - SORT_KEYS[5], - 10L, - SORT_KEYS[6], - 10L, - SORT_KEYS[7], - 10L, - SORT_KEYS[8], - 10L, - SORT_KEYS[9], - 10L)); + private final Map mapStatistics = + ImmutableMap.of( + SORT_KEYS[0], + 350L, + SORT_KEYS[1], + 230L, + SORT_KEYS[2], + 120L, + SORT_KEYS[3], + 40L, + SORT_KEYS[4], + 10L, + SORT_KEYS[5], + 10L, + SORT_KEYS[6], + 10L, + SORT_KEYS[7], + 10L, + 
SORT_KEYS[8], + 10L, + SORT_KEYS[9], + 10L); @Test public void testEvenlyDividableNoClosingFileCost() { - MapRangePartitioner partitioner = - new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapDataStatistics, 0.0); int numPartitions = 8; + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(numPartitions, mapStatistics, 0.0, SORT_ORDER_COMPARTOR); // each task should get targeted weight of 100 (=800/8) - Map expectedAssignment = + Map expectedAssignment = ImmutableMap.of( SORT_KEYS[0], - new MapRangePartitioner.KeyAssignment( + new KeyAssignment( ImmutableList.of(0, 1, 2, 3), ImmutableList.of(100L, 100L, 100L, 50L), 0L), SORT_KEYS[1], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(3, 4, 5), ImmutableList.of(50L, 100L, 80L), 0L), + new KeyAssignment(ImmutableList.of(3, 4, 5), ImmutableList.of(50L, 100L, 80L), 0L), SORT_KEYS[2], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(5, 6), ImmutableList.of(20L, 100L), 0L), + new KeyAssignment(ImmutableList.of(5, 6), ImmutableList.of(20L, 100L), 0L), SORT_KEYS[3], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(40L), 0L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(40L), 0L), SORT_KEYS[4], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), SORT_KEYS[5], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), SORT_KEYS[6], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), SORT_KEYS[7], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), SORT_KEYS[8], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L), SORT_KEYS[9], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L)); - Map actualAssignment = - partitioner.assignment(numPartitions); - assertThat(actualAssignment).isEqualTo(expectedAssignment); + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(10L), 0L)); + assertThat(mapAssignment).isEqualTo(new MapAssignment(numPartitions, expectedAssignment)); // key: subtask id // value pair: first is the assigned weight, second is the number of assigned keys @@ -144,19 +140,20 @@ public void testEvenlyDividableNoClosingFileCost() { Pair.of(100L, 1), 7, Pair.of(100L, 7)); - Map> actualAssignmentInfo = partitioner.assignmentInfo(); - assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo); + assertThat(mapAssignment.assignmentInfo()).isEqualTo(expectedAssignmentInfo); + MapRangePartitioner partitioner = + new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapAssignment); Map>> partitionResults = - runPartitioner(partitioner, numPartitions); + runPartitioner(partitioner, numPartitions, mapStatistics); validatePartitionResults(expectedAssignmentInfo, partitionResults, 5.0); } @Test public void testEvenlyDividableWithClosingFileCost() { - MapRangePartitioner partitioner = - new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapDataStatistics, 5.0); int numPartitions = 8; + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(numPartitions, mapStatistics, 5.0, SORT_ORDER_COMPARTOR); // 
target subtask weight is 100 before close file cost factored in. // close file cost is 5 = 5% * 100. @@ -165,35 +162,30 @@ public void testEvenlyDividableWithClosingFileCost() { // close-cost: 20, 15, 10, 5, 5, 5, 5, 5, 5, 5 // after: 370, 245, 130, 45, 15, 15, 15, 15, 15, 15 // target subtask weight with close cost per subtask is 110 (880/8) - Map expectedAssignment = + Map expectedAssignment = ImmutableMap.of( SORT_KEYS[0], - new MapRangePartitioner.KeyAssignment( + new KeyAssignment( ImmutableList.of(0, 1, 2, 3), ImmutableList.of(110L, 110L, 110L, 40L), 5L), SORT_KEYS[1], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(3, 4, 5), ImmutableList.of(70L, 110L, 65L), 5L), + new KeyAssignment(ImmutableList.of(3, 4, 5), ImmutableList.of(70L, 110L, 65L), 5L), SORT_KEYS[2], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(5, 6), ImmutableList.of(45L, 85L), 5L), + new KeyAssignment(ImmutableList.of(5, 6), ImmutableList.of(45L, 85L), 5L), SORT_KEYS[3], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(6, 7), ImmutableList.of(25L, 20L), 5L), + new KeyAssignment(ImmutableList.of(6, 7), ImmutableList.of(25L, 20L), 5L), SORT_KEYS[4], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), SORT_KEYS[5], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), SORT_KEYS[6], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), SORT_KEYS[7], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), SORT_KEYS[8], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L), SORT_KEYS[9], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L)); - Map actualAssignment = - partitioner.assignment(numPartitions); - assertThat(actualAssignment).isEqualTo(expectedAssignment); + new KeyAssignment(ImmutableList.of(7), ImmutableList.of(15L), 5L)); + assertThat(mapAssignment.keyAssignments()).isEqualTo(expectedAssignment); // key: subtask id // value pair: first is the assigned weight (excluding close file cost) for the subtask, @@ -216,51 +208,48 @@ public void testEvenlyDividableWithClosingFileCost() { Pair.of(100L, 2), 7, Pair.of(75L, 7)); - Map> actualAssignmentInfo = partitioner.assignmentInfo(); - assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo); + assertThat(mapAssignment.assignmentInfo()).isEqualTo(expectedAssignmentInfo); + MapRangePartitioner partitioner = + new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapAssignment); Map>> partitionResults = - runPartitioner(partitioner, numPartitions); + runPartitioner(partitioner, numPartitions, mapStatistics); validatePartitionResults(expectedAssignmentInfo, partitionResults, 5.0); } @Test public void testNonDividableNoClosingFileCost() { - MapRangePartitioner partitioner = - new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapDataStatistics, 0.0); int numPartitions = 9; + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(numPartitions, mapStatistics, 0.0, SORT_ORDER_COMPARTOR); // before: 350, 230, 120, 40, 10, 10, 10, 10, 10, 10 // each task should get 
targeted weight of 89 = ceiling(800/9) - Map expectedAssignment = + Map expectedAssignment = ImmutableMap.of( SORT_KEYS[0], - new MapRangePartitioner.KeyAssignment( + new KeyAssignment( ImmutableList.of(0, 1, 2, 3), ImmutableList.of(89L, 89L, 89L, 83L), 0L), SORT_KEYS[1], - new MapRangePartitioner.KeyAssignment( + new KeyAssignment( ImmutableList.of(3, 4, 5, 6), ImmutableList.of(6L, 89L, 89L, 46L), 0L), SORT_KEYS[2], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(6, 7), ImmutableList.of(43L, 77L), 0L), + new KeyAssignment(ImmutableList.of(6, 7), ImmutableList.of(43L, 77L), 0L), SORT_KEYS[3], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(7, 8), ImmutableList.of(12L, 28L), 0L), + new KeyAssignment(ImmutableList.of(7, 8), ImmutableList.of(12L, 28L), 0L), SORT_KEYS[4], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), SORT_KEYS[5], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), SORT_KEYS[6], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), SORT_KEYS[7], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), SORT_KEYS[8], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L), SORT_KEYS[9], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L)); - Map actualAssignment = - partitioner.assignment(numPartitions); - assertThat(actualAssignment).isEqualTo(expectedAssignment); + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(10L), 0L)); + assertThat(mapAssignment.keyAssignments()).isEqualTo(expectedAssignment); // key: subtask id // value pair: first is the assigned weight, second is the number of assigned keys @@ -284,19 +273,20 @@ public void testNonDividableNoClosingFileCost() { Pair.of(89L, 2), 8, Pair.of(88L, 7)); - Map> actualAssignmentInfo = partitioner.assignmentInfo(); - assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo); + assertThat(mapAssignment.assignmentInfo()).isEqualTo(expectedAssignmentInfo); + MapRangePartitioner partitioner = + new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapAssignment); Map>> partitionResults = - runPartitioner(partitioner, numPartitions); + runPartitioner(partitioner, numPartitions, mapStatistics); validatePartitionResults(expectedAssignmentInfo, partitionResults, 5.0); } @Test public void testNonDividableWithClosingFileCost() { - MapRangePartitioner partitioner = - new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapDataStatistics, 5.0); int numPartitions = 9; + MapAssignment mapAssignment = + MapAssignment.fromKeyFrequency(numPartitions, mapStatistics, 5.0, SORT_ORDER_COMPARTOR); // target subtask weight is 89 before close file cost factored in. // close file cost is 5 (= 5% * 89) per file. 
@@ -305,35 +295,31 @@ public void testNonDividableWithClosingFileCost() { // close-cost: 20, 15, 10, 5, 5, 5, 5, 5, 5, 5 // after: 370, 245, 130, 45, 15, 15, 15, 15, 15, 15 // target subtask weight per subtask is 98 ceiling(880/9) - Map expectedAssignment = + Map expectedAssignment = ImmutableMap.of( SORT_KEYS[0], - new MapRangePartitioner.KeyAssignment( + new KeyAssignment( ImmutableList.of(0, 1, 2, 3), ImmutableList.of(98L, 98L, 98L, 76L), 5L), SORT_KEYS[1], - new MapRangePartitioner.KeyAssignment( + new KeyAssignment( ImmutableList.of(3, 4, 5, 6), ImmutableList.of(22L, 98L, 98L, 27L), 5L), SORT_KEYS[2], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(6, 7), ImmutableList.of(71L, 59L), 5L), + new KeyAssignment(ImmutableList.of(6, 7), ImmutableList.of(71L, 59L), 5L), SORT_KEYS[3], - new MapRangePartitioner.KeyAssignment( - ImmutableList.of(7, 8), ImmutableList.of(39L, 6L), 5L), + new KeyAssignment(ImmutableList.of(7, 8), ImmutableList.of(39L, 6L), 5L), SORT_KEYS[4], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), SORT_KEYS[5], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), SORT_KEYS[6], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), SORT_KEYS[7], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), SORT_KEYS[8], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L), SORT_KEYS[9], - new MapRangePartitioner.KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L)); - Map actualAssignment = - partitioner.assignment(numPartitions); - assertThat(actualAssignment).isEqualTo(expectedAssignment); + new KeyAssignment(ImmutableList.of(8), ImmutableList.of(15L), 5L)); + assertThat(mapAssignment.keyAssignments()).isEqualTo(expectedAssignment); // key: subtask id // value pair: first is the assigned weight for the subtask, second is the number of keys @@ -358,40 +344,39 @@ public void testNonDividableWithClosingFileCost() { Pair.of(88L, 2), 8, Pair.of(61L, 7)); - Map> actualAssignmentInfo = partitioner.assignmentInfo(); - assertThat(actualAssignmentInfo).isEqualTo(expectedAssignmentInfo); + assertThat(mapAssignment.assignmentInfo()).isEqualTo(expectedAssignmentInfo); + MapRangePartitioner partitioner = + new MapRangePartitioner(TestFixtures.SCHEMA, SORT_ORDER, mapAssignment); Map>> partitionResults = - runPartitioner(partitioner, numPartitions); + runPartitioner(partitioner, numPartitions, mapStatistics); // drift threshold is high for non-dividable scenario with close cost validatePartitionResults(expectedAssignmentInfo, partitionResults, 10.0); } private static Map>> runPartitioner( - MapRangePartitioner partitioner, int numPartitions) { + MapRangePartitioner partitioner, int numPartitions, Map mapStatistics) { // The Map key is the subtaskId. // For the map value pair, the first element is the count of assigned and // the second element of Set is for the set of assigned keys. 
Map>> partitionResults = Maps.newHashMap(); - partitioner - .mapStatistics() - .forEach( - (sortKey, weight) -> { - String key = sortKey.get(0, String.class); - // run 100x times of the weight - long iterations = weight * 100; - for (int i = 0; i < iterations; ++i) { - RowData rowData = - GenericRowData.of( - StringData.fromString(key), 1, StringData.fromString("2023-06-20")); - int subtaskId = partitioner.partition(rowData, numPartitions); - partitionResults.computeIfAbsent( - subtaskId, k -> Pair.of(new AtomicLong(0), Sets.newHashSet())); - Pair> pair = partitionResults.get(subtaskId); - pair.first().incrementAndGet(); - pair.second().add(rowData); - } - }); + mapStatistics.forEach( + (sortKey, weight) -> { + String key = sortKey.get(0, String.class); + // run 100x times of the weight + long iterations = weight * 100; + for (int i = 0; i < iterations; ++i) { + RowData rowData = + GenericRowData.of( + StringData.fromString(key), 1, StringData.fromString("2023-06-20")); + int subtaskId = partitioner.partition(rowData, numPartitions); + partitionResults.computeIfAbsent( + subtaskId, k -> Pair.of(new AtomicLong(0), Sets.newHashSet())); + Pair> pair = partitionResults.get(subtaskId); + pair.first().incrementAndGet(); + pair.second().add(rowData); + } + }); return partitionResults; } diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSketchDataStatistics.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSketchDataStatistics.java new file mode 100644 index 000000000000..396bfae2f13c --- /dev/null +++ b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSketchDataStatistics.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.ROW_WRAPPER; +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.datasketches.sampling.ReservoirItemsSketch; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.StringData; +import org.apache.iceberg.SortKey; +import org.junit.jupiter.api.Test; + +public class TestSketchDataStatistics { + @SuppressWarnings("unchecked") + @Test + public void testAddsAndGet() { + SketchDataStatistics dataStatistics = new SketchDataStatistics(128); + + GenericRowData reusedRow = GenericRowData.of(StringData.fromString("a"), 1); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); + + reusedRow.setField(0, StringData.fromString("b")); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); + + reusedRow.setField(0, StringData.fromString("c")); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); + + reusedRow.setField(0, StringData.fromString("b")); + Fixtures.SORT_KEY.wrap(ROW_WRAPPER.wrap(reusedRow)); + dataStatistics.add(Fixtures.SORT_KEY); + + ReservoirItemsSketch actual = (ReservoirItemsSketch) dataStatistics.result(); + assertThat(actual.getSamples()) + .isEqualTo( + new SortKey[] { + CHAR_KEYS.get("a"), CHAR_KEYS.get("b"), CHAR_KEYS.get("c"), CHAR_KEYS.get("b") + }); + } +} diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSketchUtil.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSketchUtil.java new file mode 100644 index 000000000000..31dae5c76aeb --- /dev/null +++ b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSketchUtil.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.flink.sink.shuffle; + +import static org.apache.iceberg.flink.sink.shuffle.Fixtures.CHAR_KEYS; +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.iceberg.SortKey; +import org.junit.jupiter.api.Test; + +public class TestSketchUtil { + @Test + public void testCoordinatorReservoirSize() { + // adjusted to over min threshold of 10_000 and is divisible by number of partitions (3) + assertThat(SketchUtil.determineCoordinatorReservoirSize(3)).isEqualTo(10_002); + // adjust to multiplier of 100 + assertThat(SketchUtil.determineCoordinatorReservoirSize(123)).isEqualTo(123_00); + // adjusted to below max threshold of 1_000_000 and is divisible by number of partitions (3) + assertThat(SketchUtil.determineCoordinatorReservoirSize(10_123)) + .isEqualTo(1_000_000 - (1_000_000 % 10_123)); + } + + @Test + public void testOperatorReservoirSize() { + assertThat(SketchUtil.determineOperatorReservoirSize(5, 3)) + .isEqualTo((10_002 * SketchUtil.OPERATOR_OVER_SAMPLE_RATIO) / 5); + assertThat(SketchUtil.determineOperatorReservoirSize(123, 123)) + .isEqualTo((123_00 * SketchUtil.OPERATOR_OVER_SAMPLE_RATIO) / 123); + assertThat(SketchUtil.determineOperatorReservoirSize(256, 123)) + .isEqualTo( + (int) Math.ceil((double) (123_00 * SketchUtil.OPERATOR_OVER_SAMPLE_RATIO) / 256)); + assertThat(SketchUtil.determineOperatorReservoirSize(5_120, 10_123)) + .isEqualTo( + (int) Math.ceil((double) (992_054 * SketchUtil.OPERATOR_OVER_SAMPLE_RATIO) / 5_120)); + } + + @Test + public void testRangeBoundsOneChannel() { + assertThat( + SketchUtil.rangeBounds( + 1, + Fixtures.SORT_ORDER_COMPARTOR, + new SortKey[] { + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d"), + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f") + })) + .isEmpty(); + } + + @Test + public void testRangeBoundsDivisible() { + assertThat( + SketchUtil.rangeBounds( + 3, + Fixtures.SORT_ORDER_COMPARTOR, + new SortKey[] { + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d"), + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f") + })) + .containsExactly(CHAR_KEYS.get("b"), CHAR_KEYS.get("d")); + } + + @Test + public void testRangeBoundsNonDivisible() { + // step is 3 = ceiling(11/4) + assertThat( + SketchUtil.rangeBounds( + 4, + Fixtures.SORT_ORDER_COMPARTOR, + new SortKey[] { + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("d"), + CHAR_KEYS.get("e"), + CHAR_KEYS.get("f"), + CHAR_KEYS.get("g"), + CHAR_KEYS.get("h"), + CHAR_KEYS.get("i"), + CHAR_KEYS.get("j"), + CHAR_KEYS.get("k"), + })) + .containsExactly(CHAR_KEYS.get("c"), CHAR_KEYS.get("f"), CHAR_KEYS.get("i")); + } + + @Test + public void testRangeBoundsSkipDuplicates() { + // step is 3 = ceiling(11/4) + assertThat( + SketchUtil.rangeBounds( + 4, + Fixtures.SORT_ORDER_COMPARTOR, + new SortKey[] { + CHAR_KEYS.get("a"), + CHAR_KEYS.get("b"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("c"), + CHAR_KEYS.get("g"), + CHAR_KEYS.get("h"), + CHAR_KEYS.get("i"), + CHAR_KEYS.get("j"), + CHAR_KEYS.get("k"), + })) + // skipped duplicate c's + .containsExactly(CHAR_KEYS.get("c"), CHAR_KEYS.get("g"), CHAR_KEYS.get("j")); + } +} diff --git a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerPrimitives.java b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerPrimitives.java index 291302aef486..54cceae6e55b 100644 --- 
a/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerPrimitives.java +++ b/flink/v1.18/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerPrimitives.java @@ -18,14 +18,24 @@ */ package org.apache.iceberg.flink.sink.shuffle; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import org.apache.flink.core.memory.DataInputDeserializer; +import org.apache.flink.core.memory.DataOutputSerializer; import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; import org.apache.iceberg.NullOrder; import org.apache.iceberg.Schema; import org.apache.iceberg.SortDirection; +import org.apache.iceberg.SortKey; import org.apache.iceberg.SortOrder; +import org.apache.iceberg.StructLike; import org.apache.iceberg.expressions.Expressions; import org.apache.iceberg.flink.DataGenerator; import org.apache.iceberg.flink.DataGenerators; +import org.apache.iceberg.flink.RowDataWrapper; +import org.junit.jupiter.api.Test; public class TestSortKeySerializerPrimitives extends TestSortKeySerializerBase { private final DataGenerator generator = new DataGenerators.Primitives(); @@ -54,4 +64,27 @@ protected SortOrder sortOrder() { protected GenericRowData rowData() { return generator.generateFlinkRowData(); } + + @Test + public void testSerializationSize() throws Exception { + RowData rowData = + GenericRowData.of(StringData.fromString("550e8400-e29b-41d4-a716-446655440000"), 1L); + RowDataWrapper rowDataWrapper = + new RowDataWrapper(Fixtures.ROW_TYPE, Fixtures.SCHEMA.asStruct()); + StructLike struct = rowDataWrapper.wrap(rowData); + SortKey sortKey = Fixtures.SORT_KEY.copy(); + sortKey.wrap(struct); + SortKeySerializer serializer = new SortKeySerializer(Fixtures.SCHEMA, Fixtures.SORT_ORDER); + DataOutputSerializer output = new DataOutputSerializer(1024); + serializer.serialize(sortKey, output); + byte[] serializedBytes = output.getCopyOfBuffer(); + assertThat(serializedBytes.length) + .as( + "Serialized bytes for sort key should be 38 bytes (34 UUID text + 4 byte integer of string length") + .isEqualTo(38); + + DataInputDeserializer input = new DataInputDeserializer(serializedBytes); + SortKey deserialized = serializer.deserialize(input); + assertThat(deserialized).isEqualTo(sortKey); + } }
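Note on the expected values in TestMapRangePartitioner above: the targeted per-subtask weights (100, 110, 89, 98) come from arithmetic that is only spelled out in the test comments. The following is a minimal standalone sketch of that arithmetic, not the MapAssignment implementation; it assumes the close-file cost is charged once per estimated output file for a key (ceiling of weight / target weight), which is consistent with the figures asserted in the tests.

import java.util.stream.LongStream;

// Hypothetical helper reproducing the weight math from the TestMapRangePartitioner comments.
public class WeightMathSketch {
  private static long ceilDiv(long dividend, long divisor) {
    return (dividend + divisor - 1) / divisor;
  }

  static void printTargets(long[] weights, int numPartitions, double closeCostPercent) {
    long totalWeight = LongStream.of(weights).sum();            // 800 for the fixture above
    long target = ceilDiv(totalWeight, numPartitions);          // 100 for 8 subtasks, 89 for 9
    long closeCost = (long) Math.ceil(target * closeCostPercent / 100); // 5 in both cases
    long totalWithCloseCost = 0;
    for (long weight : weights) {
      long estimatedFiles = ceilDiv(weight, target);             // 350 -> 4 files, 230 -> 3, ...
      totalWithCloseCost += weight + estimatedFiles * closeCost; // 350 -> 370, 230 -> 245, ...
    }
    // 880 / 8 = 110 and ceiling(880 / 9) = 98, matching the expected assignments in the tests
    long targetWithCloseCost = ceilDiv(totalWithCloseCost, numPartitions);
    System.out.printf(
        "target=%d closeCost=%d targetWithCloseCost=%d%n", target, closeCost, targetWithCloseCost);
  }

  public static void main(String[] args) {
    long[] weights = {350, 230, 120, 40, 10, 10, 10, 10, 10, 10};
    printTargets(weights, 8, 5.0); // evenly dividable cases
    printTargets(weights, 9, 5.0); // non-dividable cases
  }
}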
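The TestSketchUtil range-bounds cases above follow a sampling stride of ceiling(numSamples / numPartitions), emitting numPartitions - 1 bounds and skipping a candidate that duplicates the previous bound. The sketch below is a hypothetical reconstruction that reproduces the four asserted cases; the real SketchUtil.rangeBounds may differ in detail.

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical stand-in for SketchUtil.rangeBounds, matching the asserted test cases.
public class RangeBoundsSketch {
  static <T> List<T> rangeBounds(int numPartitions, Comparator<T> comparator, List<T> sortedSamples) {
    int numBounds = numPartitions - 1; // one bound between each pair of adjacent ranges
    // e.g. step is 3 = ceiling(11/4) for 11 samples and 4 partitions
    int step = (int) Math.ceil((double) sortedSamples.size() / numPartitions);
    List<T> bounds = new ArrayList<>();
    int index = step - 1;
    while (bounds.size() < numBounds && index < sortedSamples.size()) {
      T candidate = sortedSamples.get(index);
      if (bounds.isEmpty() || comparator.compare(bounds.get(bounds.size() - 1), candidate) != 0) {
        bounds.add(candidate);
        index += step;
      } else {
        // skip duplicates of the previous bound so no range collapses to empty
        index++;
      }
    }
    return bounds;
  }

  public static void main(String[] args) {
    List<String> samples = List.of("a", "b", "c", "c", "c", "c", "g", "h", "i", "j", "k");
    // prints [c, g, j], matching testRangeBoundsSkipDuplicates
    System.out.println(rangeBounds(4, Comparator.<String>naturalOrder(), samples));
  }
}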
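The DataStatisticsOperator migration tests earlier in this diff lean on the DataSketches reservoir-sampling behavior: getK() is the configured capacity, getN() counts every offered item, and getSamples() is capped at K once the reservoir saturates. A tiny standalone illustration follows, assuming only the org.apache.datasketches ReservoirItemsSketch API already exercised by those tests (newInstance is assumed as the factory method).

import org.apache.datasketches.sampling.ReservoirItemsSketch;

// Illustrates the K vs. N distinction asserted by the operator sketch tests.
public class ReservoirSketchDemo {
  public static void main(String[] args) {
    int reservoirSize = 4;
    ReservoirItemsSketch<String> sketch = ReservoirItemsSketch.newInstance(reservoirSize);

    for (int i = 0; i < 10; i++) {
      sketch.update("key-" + i);
    }

    System.out.println(sketch.getK());               // 4  - configured reservoir capacity
    System.out.println(sketch.getN());               // 10 - total items seen, keeps growing past K
    System.out.println(sketch.getSamples().length);  // 4  - retained samples == K once saturated
  }
}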