From d693f83b6ad6e306550d117ff0bf0014356500b2 Mon Sep 17 00:00:00 2001
From: Manu Zhang
Date: Sat, 25 Jan 2025 07:51:46 +0800
Subject: [PATCH] Spark 3.5: Fix broadcasting specs in RewriteTablePath
 (#11982)

---
 .../actions/RewriteTablePathSparkAction.java | 71 +++++++------------
 .../actions/TestRewriteTablePathsAction.java | 28 ++++++++
 2 files changed, 54 insertions(+), 45 deletions(-)

diff --git a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteTablePathSparkAction.java b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteTablePathSparkAction.java
index 4d5d1db38e25..55888f7f5e82 100644
--- a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteTablePathSparkAction.java
+++ b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/RewriteTablePathSparkAction.java
@@ -36,7 +36,6 @@
 import org.apache.iceberg.RewriteTablePathUtil.PositionDeleteReaderWriter;
 import org.apache.iceberg.RewriteTablePathUtil.RewriteResult;
 import org.apache.iceberg.Schema;
-import org.apache.iceberg.SerializableTable;
 import org.apache.iceberg.Snapshot;
 import org.apache.iceberg.StaticTableOperations;
 import org.apache.iceberg.StructLike;
@@ -63,10 +62,12 @@
 import org.apache.iceberg.io.OutputFile;
 import org.apache.iceberg.orc.ORC;
 import org.apache.iceberg.parquet.Parquet;
+import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.relocated.com.google.common.collect.Sets;
 import org.apache.iceberg.spark.JobGroupInfo;
+import org.apache.iceberg.spark.source.SerializableTableWithSize;
 import org.apache.iceberg.util.Pair;
 import org.apache.spark.api.java.function.ForeachFunction;
 import org.apache.spark.api.java.function.MapFunction;
@@ -96,6 +97,7 @@ public class RewriteTablePathSparkAction extends BaseSparkAction
+  private Broadcast<Table> tableBroadcast = null;
 
   RewriteTablePathSparkAction(SparkSession spark, Table table) {
     super(spark);
@@ -457,18 +459,13 @@ private RewriteContentFileResult rewriteManifests(
     Dataset<ManifestFile> manifestDS =
         spark().createDataset(Lists.newArrayList(toRewrite), manifestFileEncoder);
 
-    Broadcast<Table> serializableTable = sparkContext().broadcast(SerializableTable.copyOf(table));
-    Broadcast<Map<Integer, PartitionSpec>> specsById =
-        sparkContext().broadcast(tableMetadata.specsById());
-
     return manifestDS
         .repartition(toRewrite.size())
         .map(
             toManifests(
-                serializableTable,
+                tableBroadcast(),
                 stagingDir,
                 tableMetadata.formatVersion(),
-                specsById,
                 sourcePrefix,
                 targetPrefix),
             Encoders.bean(RewriteContentFileResult.class))
@@ -478,10 +475,9 @@ private RewriteContentFileResult rewriteManifests(
   }
 
   private static MapFunction<ManifestFile, RewriteContentFileResult> toManifests(
-      Broadcast<Table> tableBroadcast,
+      Broadcast<Table> table,
       String stagingLocation,
       int format,
-      Broadcast<Map<Integer, PartitionSpec>> specsById,
       String sourcePrefix,
       String targetPrefix) {
 
@@ -491,24 +487,12 @@ private static MapFunction<ManifestFile, RewriteContentFileResult> toManifests(
         case DATA:
           result.appendDataFile(
               writeDataManifest(
-                  manifestFile,
-                  tableBroadcast,
-                  stagingLocation,
-                  format,
-                  specsById,
-                  sourcePrefix,
-                  targetPrefix));
+                  manifestFile, table, stagingLocation, format, sourcePrefix, targetPrefix));
           break;
         case DELETES:
           result.appendDeleteFile(
               writeDeleteManifest(
-                  manifestFile,
-                  tableBroadcast,
-                  stagingLocation,
-                  format,
-                  specsById,
-                  sourcePrefix,
-                  targetPrefix));
+                  manifestFile, table, stagingLocation, format, sourcePrefix, targetPrefix));
           break;
         default:
           throw new UnsupportedOperationException(
@@ -520,17 +504,16 @@ private static MapFunction<ManifestFile, RewriteContentFileResult> toManifests(
 
   private static RewriteResult<DataFile> writeDataManifest(
       ManifestFile manifestFile,
-      Broadcast<Table> tableBroadcast,
+      Broadcast<Table> table,
       String stagingLocation,
       int format,
-      Broadcast<Map<Integer, PartitionSpec>> specsByIdBroadcast,
       String sourcePrefix,
       String targetPrefix) {
     try {
       String stagingPath = RewriteTablePathUtil.stagingPath(manifestFile.path(), stagingLocation);
-      FileIO io = tableBroadcast.getValue().io();
+      FileIO io = table.getValue().io();
       OutputFile outputFile = io.newOutputFile(stagingPath);
-      Map<Integer, PartitionSpec> specsById = specsByIdBroadcast.getValue();
+      Map<Integer, PartitionSpec> specsById = table.getValue().specs();
       return RewriteTablePathUtil.rewriteDataManifest(
           manifestFile, outputFile, io, format, specsById, sourcePrefix, targetPrefix);
     } catch (IOException e) {
@@ -540,17 +523,16 @@ private static RewriteResult<DataFile> writeDataManifest(
 
   private static RewriteResult<DeleteFile> writeDeleteManifest(
       ManifestFile manifestFile,
-      Broadcast<Table> tableBroadcast,
+      Broadcast<Table> table,
       String stagingLocation,
       int format,
-      Broadcast<Map<Integer, PartitionSpec>> specsByIdBroadcast,
       String sourcePrefix,
       String targetPrefix) {
     try {
       String stagingPath = RewriteTablePathUtil.stagingPath(manifestFile.path(), stagingLocation);
-      FileIO io = tableBroadcast.getValue().io();
+      FileIO io = table.getValue().io();
       OutputFile outputFile = io.newOutputFile(stagingPath);
-      Map<Integer, PartitionSpec> specsById = specsByIdBroadcast.getValue();
+      Map<Integer, PartitionSpec> specsById = table.getValue().specs();
       return RewriteTablePathUtil.rewriteDeleteManifest(
           manifestFile,
           outputFile,
@@ -574,21 +556,12 @@ private void rewritePositionDeletes(TableMetadata metadata, Set<DeleteFile> toRe
     Dataset<DeleteFile> deleteFileDs =
         spark().createDataset(Lists.newArrayList(toRewrite), deleteFileEncoder);
 
-    Broadcast<Table> serializableTable = sparkContext().broadcast(SerializableTable.copyOf(table));
-    Broadcast<Map<Integer, PartitionSpec>> specsById =
-        sparkContext().broadcast(metadata.specsById());
-
     PositionDeleteReaderWriter posDeleteReaderWriter = new SparkPositionDeleteReaderWriter();
     deleteFileDs
         .repartition(toRewrite.size())
         .foreach(
             rewritePositionDelete(
-                serializableTable,
-                specsById,
-                sourcePrefix,
-                targetPrefix,
-                stagingDir,
-                posDeleteReaderWriter));
+                tableBroadcast(), sourcePrefix, targetPrefix, stagingDir, posDeleteReaderWriter));
   }
 
   private static class SparkPositionDeleteReaderWriter implements PositionDeleteReaderWriter {
@@ -611,17 +584,16 @@ public PositionDeleteWriter writer(
   }
 
   private ForeachFunction<DeleteFile> rewritePositionDelete(
-      Broadcast<Table> tableBroadcast,
-      Broadcast<Map<Integer, PartitionSpec>> specsById,
+      Broadcast<Table> tableArg,
       String sourcePrefixArg,
       String targetPrefixArg,
       String stagingLocationArg,
       PositionDeleteReaderWriter posDeleteReaderWriter) {
     return deleteFile -> {
-      FileIO io = tableBroadcast.getValue().io();
+      FileIO io = tableArg.getValue().io();
       String newPath = RewriteTablePathUtil.stagingPath(deleteFile.location(), stagingLocationArg);
       OutputFile outputFile = io.newOutputFile(newPath);
-      PartitionSpec spec = specsById.getValue().get(deleteFile.specId());
+      PartitionSpec spec = tableArg.getValue().specs().get(deleteFile.specId());
       RewriteTablePathUtil.rewritePositionDeleteFile(
           deleteFile,
           outputFile,
@@ -730,4 +702,13 @@ private String getMetadataLocation(Table tbl) {
         !metadataDir.isEmpty(), "Failed to get the metadata file root directory");
     return metadataDir;
   }
+
+  @VisibleForTesting
+  Broadcast<Table> tableBroadcast() {
+    if (tableBroadcast == null) {
+      this.tableBroadcast = sparkContext().broadcast(SerializableTableWithSize.copyOf(table));
+    }
+
+    return tableBroadcast;
+  }
 }
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteTablePathsAction.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteTablePathsAction.java
index dba7ff197b39..644ddcd27ef4 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteTablePathsAction.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/actions/TestRewriteTablePathsAction.java
@@ -64,10 +64,16 @@
 import org.apache.iceberg.spark.source.ThreeColumnRecord;
 import org.apache.iceberg.types.Types;
 import org.apache.iceberg.util.Pair;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Row;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockInfoManager;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.BroadcastBlockId;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -910,6 +916,17 @@ public void testDeleteFrom() throws Exception {
     assertEquals("Rows must match", originalData, copiedData);
   }
 
+  @Test
+  public void testKryoDeserializeBroadcastValues() {
+    sparkContext.getConf().set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
+    RewriteTablePathSparkAction action =
+        (RewriteTablePathSparkAction) actions().rewriteTablePath(table);
+    Broadcast<Table> tableBroadcast = action.tableBroadcast();
+    // force deserializing broadcast values
+    removeBroadcastValuesFromLocalBlockManager(tableBroadcast.id());
+    assertThat(tableBroadcast.getValue().uuid()).isEqualTo(table.uuid());
+  }
+
   protected void checkFileNum(
       int versionFileCount,
       int manifestListCount,
@@ -1049,4 +1066,15 @@ private PositionDelete positionDelete(
     posDelete.set(path, position, nested);
     return posDelete;
   }
+
+  private void removeBroadcastValuesFromLocalBlockManager(long id) {
+    BlockId blockId = new BroadcastBlockId(id, "");
+    SparkEnv env = SparkEnv.get();
+    env.broadcastManager().cachedValues().clear();
+    BlockManager blockManager = env.blockManager();
+    BlockInfoManager blockInfoManager = blockManager.blockInfoManager();
+    blockInfoManager.lockForWriting(blockId, true);
+    blockInfoManager.removeBlock(blockId);
+    blockManager.memoryStore().remove(blockId);
+  }
 }