diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java b/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java
index 29405466b93f..b9460f28b4e7 100644
--- a/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/source/DataSplit.java
@@ -44,6 +44,7 @@
import static org.apache.paimon.io.DataFilePathFactory.INDEX_PATH_SUFFIX;
import static org.apache.paimon.utils.Preconditions.checkArgument;
+import static org.apache.paimon.utils.Preconditions.checkState;
/** Input splits. Needed by most batch computation engines. */
public class DataSplit implements Split {
@@ -126,6 +127,45 @@ public long rowCount() {
return rowCount;
}
+ /** Whether it is possible to calculate the merged row count. */
+ public boolean mergedRowCountAvailable() {
+ return rawConvertible
+ && (dataDeletionFiles == null
+ || dataDeletionFiles.stream()
+ .allMatch(f -> f == null || f.cardinality() != null));
+ }
+
+ public long mergedRowCount() {
+ checkState(mergedRowCountAvailable());
+ return partialMergedRowCount();
+ }
+
+    /**
+     * Obtain the merged row count as much as possible. There are two scenarios where an accurate
+     * row count can be calculated:
+     *
+     * <p>1. raw file and no deletion file.
+     *
+     * <p>2. raw file + deletion file with cardinality.
+     */
+ public long partialMergedRowCount() {
+ long sum = 0L;
+ if (rawConvertible) {
+            List<RawFile> rawFiles = convertToRawFiles().orElse(null);
+ if (rawFiles != null) {
+ for (int i = 0; i < rawFiles.size(); i++) {
+ RawFile rawFile = rawFiles.get(i);
+ if (dataDeletionFiles == null || dataDeletionFiles.get(i) == null) {
+ sum += rawFile.rowCount();
+ } else if (dataDeletionFiles.get(i).cardinality() != null) {
+ sum += rawFile.rowCount() - dataDeletionFiles.get(i).cardinality();
+ }
+ }
+ }
+ }
+ return sum;
+ }
+
@Override
    public Optional<List<RawFile>> convertToRawFiles() {
if (rawConvertible) {
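
Note for reviewers: a minimal sketch, outside the patch, of how a caller might consume the new DataSplit row-count methods. The RowCountHints class and its method names are hypothetical; only mergedRowCountAvailable(), mergedRowCount() and partialMergedRowCount() come from this change. partialMergedRowCount() is a lower bound (raw files that still need merging, or whose deletion file lacks a cardinality, contribute nothing), while mergedRowCount() is exact and only legal when mergedRowCountAvailable() returns true.

import java.util.List;
import java.util.OptionalLong;

import org.apache.paimon.table.source.DataSplit;

// Hypothetical helper illustrating the intended contract of the new DataSplit methods.
public final class RowCountHints {

    private RowCountHints() {}

    // Exact COUNT(*) is only possible when every split can report its merged row count,
    // i.e. the split is rawConvertible and every deletion file carries a cardinality.
    public static OptionalLong exactRowCount(List<DataSplit> splits) {
        long total = 0L;
        for (DataSplit split : splits) {
            if (!split.mergedRowCountAvailable()) {
                return OptionalLong.empty();
            }
            total += split.mergedRowCount();
        }
        return OptionalLong.of(total);
    }

    // For limit push-down a lower bound is enough: partialMergedRowCount() only counts the
    // raw files whose merged row count is known, so stopping once the accumulated value
    // reaches the limit can never drop rows.
    public static long lowerBoundRowCount(List<DataSplit> splits) {
        long total = 0L;
        for (DataSplit split : splits) {
            total += split.partialMergedRowCount();
        }
        return total;
    }
}
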
diff --git a/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java b/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java
index 635802cc9dcb..a4fe6d73bba1 100644
--- a/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/source/DataTableBatchScan.java
@@ -28,7 +28,6 @@
import java.util.ArrayList;
import java.util.List;
-import java.util.Objects;
import static org.apache.paimon.CoreOptions.MergeEngine.FIRST_ROW;
@@ -103,9 +102,9 @@ private StartingScanner.Result applyPushDownLimit(StartingScanner.Result result)
            List<Split> limitedSplits = new ArrayList<>();
for (DataSplit dataSplit : splits) {
if (dataSplit.rawConvertible()) {
- long splitRowCount = getRowCountForSplit(dataSplit);
+ long partialMergedRowCount = dataSplit.partialMergedRowCount();
limitedSplits.add(dataSplit);
- scannedRowCount += splitRowCount;
+ scannedRowCount += partialMergedRowCount;
if (scannedRowCount >= pushDownLimit) {
SnapshotReader.Plan newPlan =
new PlanImpl(plan.watermark(), plan.snapshotId(), limitedSplits);
@@ -117,20 +116,6 @@ private StartingScanner.Result applyPushDownLimit(StartingScanner.Result result)
return result;
}
- /**
- * 0 represents that we can't compute the row count of this split: 1. the split needs to be
- * merged; 2. the table enabled deletion vector and there are some deletion files.
- */
- private long getRowCountForSplit(DataSplit split) {
- if (split.deletionFiles().isPresent()
- && split.deletionFiles().get().stream().anyMatch(Objects::nonNull)) {
- return 0L;
- }
- return split.convertToRawFiles()
- .map(files -> files.stream().map(RawFile::rowCount).reduce(Long::sum).orElse(0L))
- .orElse(0L);
- }
-
@Override
public DataTableScan withShard(int indexOfThisSubtask, int numberOfParallelSubtasks) {
snapshotReader.withShard(indexOfThisSubtask, numberOfParallelSubtasks);
diff --git a/paimon-core/src/test/java/org/apache/paimon/table/source/SplitGeneratorTest.java b/paimon-core/src/test/java/org/apache/paimon/table/source/SplitGeneratorTest.java
index a9e093dab124..a1f7d69e2877 100644
--- a/paimon-core/src/test/java/org/apache/paimon/table/source/SplitGeneratorTest.java
+++ b/paimon-core/src/test/java/org/apache/paimon/table/source/SplitGeneratorTest.java
@@ -43,10 +43,10 @@
public class SplitGeneratorTest {
public static DataFileMeta newFileFromSequence(
- String name, int rowCount, long minSequence, long maxSequence) {
+ String name, int fileSize, long minSequence, long maxSequence) {
return new DataFileMeta(
name,
- rowCount,
+ fileSize,
1,
EMPTY_ROW,
EMPTY_ROW,
diff --git a/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java b/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java
index 359d38c973db..0219941a0ac0 100644
--- a/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java
+++ b/paimon-core/src/test/java/org/apache/paimon/table/source/SplitTest.java
@@ -49,6 +49,41 @@
/** Test for {@link DataSplit}. */
public class SplitTest {
+ @Test
+ public void testSplitMergedRowCount() {
+ // not rawConvertible
+        List<DataFileMeta> dataFiles =
+ Arrays.asList(newDataFile(1000L), newDataFile(2000L), newDataFile(3000L));
+ DataSplit split = newDataSplit(false, dataFiles, null);
+ assertThat(split.partialMergedRowCount()).isEqualTo(0L);
+ assertThat(split.mergedRowCountAvailable()).isEqualTo(false);
+
+ // rawConvertible without deletion files
+ split = newDataSplit(true, dataFiles, null);
+ assertThat(split.partialMergedRowCount()).isEqualTo(6000L);
+ assertThat(split.mergedRowCountAvailable()).isEqualTo(true);
+ assertThat(split.mergedRowCount()).isEqualTo(6000L);
+
+ // rawConvertible with deletion files without cardinality
+        ArrayList<DeletionFile> deletionFiles = new ArrayList<>();
+ deletionFiles.add(null);
+ deletionFiles.add(new DeletionFile("p", 1, 2, null));
+ deletionFiles.add(new DeletionFile("p", 1, 2, 100L));
+ split = newDataSplit(true, dataFiles, deletionFiles);
+ assertThat(split.partialMergedRowCount()).isEqualTo(3900L);
+ assertThat(split.mergedRowCountAvailable()).isEqualTo(false);
+
+ // rawConvertible with deletion files with cardinality
+ deletionFiles = new ArrayList<>();
+ deletionFiles.add(null);
+ deletionFiles.add(new DeletionFile("p", 1, 2, 200L));
+ deletionFiles.add(new DeletionFile("p", 1, 2, 100L));
+ split = newDataSplit(true, dataFiles, deletionFiles);
+ assertThat(split.partialMergedRowCount()).isEqualTo(5700L);
+ assertThat(split.mergedRowCountAvailable()).isEqualTo(true);
+ assertThat(split.mergedRowCount()).isEqualTo(5700L);
+ }
+
@Test
public void testSerializer() throws IOException {
DataFileTestDataGenerator gen = DataFileTestDataGenerator.builder().build();
@@ -311,4 +346,36 @@ public void testSerializerCompatibleV3() throws Exception {
InstantiationUtil.deserializeObject(v2Bytes, DataSplit.class.getClassLoader());
assertThat(actual).isEqualTo(split);
}
+
+ private DataFileMeta newDataFile(long rowCount) {
+ return DataFileMeta.forAppend(
+ "my_data_file.parquet",
+ 1024 * 1024,
+ rowCount,
+ null,
+ 0L,
+ rowCount,
+ 1,
+ Collections.emptyList(),
+ null,
+ null,
+ null);
+ }
+
+ private DataSplit newDataSplit(
+ boolean rawConvertible,
+            List<DataFileMeta> dataFiles,
+            List<DeletionFile> deletionFiles) {
+ DataSplit.Builder builder = DataSplit.builder();
+ builder.withSnapshot(1)
+ .withPartition(BinaryRow.EMPTY_ROW)
+ .withBucket(1)
+ .withBucketPath("my path")
+ .rawConvertible(rawConvertible)
+ .withDataFiles(dataFiles);
+ if (deletionFiles != null) {
+ builder.withDataDeletionFiles(deletionFiles);
+ }
+ return builder.build();
+ }
}
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
index d8b66e1cd1e0..0393a1cd1578 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScanBuilder.scala
@@ -21,6 +21,7 @@ package org.apache.paimon.spark
import org.apache.paimon.predicate.PredicateBuilder
import org.apache.paimon.spark.aggregate.LocalAggregator
import org.apache.paimon.table.Table
+import org.apache.paimon.table.source.DataSplit
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates, SupportsPushDownLimit}
@@ -36,12 +37,12 @@ class PaimonScanBuilder(table: Table)
override def pushLimit(limit: Int): Boolean = {
// It is safe, since we will do nothing if it is the primary table and the split is not `rawConvertible`
pushDownLimit = Some(limit)
- // just make a best effort to push down limit
+ // just make the best effort to push down limit
false
}
override def supportCompletePushDown(aggregation: Aggregation): Boolean = {
- // for now we only support complete push down, so there is no difference with `pushAggregation`
+ // for now, we only support complete push down, so there is no difference with `pushAggregation`
pushAggregation(aggregation)
}
@@ -66,8 +67,11 @@ class PaimonScanBuilder(table: Table)
val pushedPartitionPredicate = PredicateBuilder.and(pushedPredicates.map(_._2): _*)
readBuilder.withFilter(pushedPartitionPredicate)
}
- val scan = readBuilder.newScan()
- scan.listPartitionEntries.asScala.foreach(aggregator.update)
+ val dataSplits = readBuilder.newScan().plan().splits().asScala.map(_.asInstanceOf[DataSplit])
+ if (!dataSplits.forall(_.mergedRowCountAvailable())) {
+ return false
+ }
+ dataSplits.foreach(aggregator.update)
localScan = Some(
PaimonLocalScan(
aggregator.result(),
diff --git a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala
index 41e7fd3c3ce9..8988e7218d1f 100644
--- a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala
+++ b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/aggregate/LocalAggregator.scala
@@ -19,10 +19,10 @@
package org.apache.paimon.spark.aggregate
import org.apache.paimon.data.BinaryRow
-import org.apache.paimon.manifest.PartitionEntry
import org.apache.paimon.spark.SparkTypeUtils
import org.apache.paimon.spark.data.SparkInternalRow
import org.apache.paimon.table.{DataTable, Table}
+import org.apache.paimon.table.source.DataSplit
import org.apache.paimon.utils.{InternalRowUtils, ProjectedRow}
import org.apache.spark.sql.catalyst.InternalRow
@@ -78,13 +78,7 @@ class LocalAggregator(table: Table) {
}
def pushAggregation(aggregation: Aggregation): Boolean = {
- if (
- !table.isInstanceOf[DataTable] ||
- !table.primaryKeys.isEmpty
- ) {
- return false
- }
- if (table.asInstanceOf[DataTable].coreOptions.deletionVectorsEnabled) {
+ if (!table.isInstanceOf[DataTable]) {
return false
}
@@ -108,12 +102,12 @@ class LocalAggregator(table: Table) {
SparkInternalRow.create(partitionType).replace(genericRow)
}
- def update(partitionEntry: PartitionEntry): Unit = {
+ def update(dataSplit: DataSplit): Unit = {
assert(isInitialized)
- val groupByRow = requiredGroupByRow(partitionEntry.partition())
+ val groupByRow = requiredGroupByRow(dataSplit.partition())
val aggFuncEvaluator =
groupByEvaluatorMap.getOrElseUpdate(groupByRow, aggFuncEvaluatorGetter())
- aggFuncEvaluator.foreach(_.update(partitionEntry))
+ aggFuncEvaluator.foreach(_.update(dataSplit))
}
def result(): Array[InternalRow] = {
@@ -147,7 +141,7 @@ class LocalAggregator(table: Table) {
}
trait AggFuncEvaluator[T] {
- def update(partitionEntry: PartitionEntry): Unit
+ def update(dataSplit: DataSplit): Unit
def result(): T
def resultType: DataType
def prettyName: String
@@ -156,8 +150,8 @@ trait AggFuncEvaluator[T] {
class CountStarEvaluator extends AggFuncEvaluator[Long] {
private var _result: Long = 0L
- override def update(partitionEntry: PartitionEntry): Unit = {
- _result += partitionEntry.recordCount()
+ override def update(dataSplit: DataSplit): Unit = {
+ _result += dataSplit.mergedRowCount()
}
override def result(): Long = _result
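
To make the aggregator change concrete, here is a minimal sketch (in Java, not part of the patch) of the COUNT(*)-by-partition computation that LocalAggregator now performs over DataSplits instead of PartitionEntries. The PartitionCountSketch class is hypothetical; DataSplit.partition() and mergedRowCount() are the APIs used by this change, and callers are expected to have checked mergedRowCountAvailable() for every split first, as PaimonScanBuilder does above.

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.paimon.data.BinaryRow;
import org.apache.paimon.table.source.DataSplit;

// Hypothetical sketch: COUNT(*) grouped by partition becomes a sum of merged row counts
// keyed by the split's partition row.
public final class PartitionCountSketch {

    private PartitionCountSketch() {}

    public static Map<BinaryRow, Long> countStarByPartition(List<DataSplit> splits) {
        Map<BinaryRow, Long> counts = new HashMap<>();
        for (DataSplit split : splits) {
            // Safe only if mergedRowCountAvailable() was verified for every split beforehand.
            counts.merge(split.partition(), split.mergedRowCount(), Long::sum);
        }
        return counts;
    }
}
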
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
index ba314e3afa81..503f1c8e3e9d 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PaimonPushDownTest.scala
@@ -18,7 +18,7 @@
package org.apache.paimon.spark.sql
-import org.apache.paimon.spark.{PaimonBatch, PaimonInputPartition, PaimonScan, PaimonSparkTestBase, SparkTable}
+import org.apache.paimon.spark.{PaimonScan, PaimonSparkTestBase, SparkTable}
import org.apache.paimon.table.source.DataSplit
import org.apache.spark.sql.Row
@@ -29,8 +29,6 @@ import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownLimit}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.junit.jupiter.api.Assertions
-import scala.collection.JavaConverters._
-
class PaimonPushDownTest extends PaimonSparkTestBase {
import testImplicits._
@@ -64,7 +62,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Nil)
// case 2
- // filter "id = '1' or pt = 'p1'" can't push down completely, it still need to be evaluated after scanning
+ // filter "id = '1' or pt = 'p1'" can't push down completely, it still needs to be evaluated after scanning
q = "SELECT * FROM T WHERE id = '1' or pt = 'p1'"
Assertions.assertTrue(checkEqualToFilterExists(q, "pt", Literal("p1")))
checkAnswer(spark.sql(q), Row(1, "a", "p1") :: Row(2, "b", "p1") :: Nil)
@@ -121,7 +119,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
val dataSplitsWithoutLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits
Assertions.assertTrue(dataSplitsWithoutLimit.length >= 2)
- // It still return false even it can push down limit.
+    // It still returns false even if it can push down limit.
Assertions.assertFalse(scanBuilder.asInstanceOf[SupportsPushDownLimit].pushLimit(1))
val dataSplitsWithLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits
Assertions.assertEquals(1, dataSplitsWithLimit.length)
@@ -169,12 +167,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
// Now, we have 4 dataSplits, and 2 dataSplit is nonRawConvertible, 2 dataSplit is rawConvertible.
Assertions.assertEquals(
2,
- dataSplitsWithoutLimit2
- .filter(
- split => {
- split.asInstanceOf[DataSplit].rawConvertible()
- })
- .length)
+ dataSplitsWithoutLimit2.count(split => { split.asInstanceOf[DataSplit].rawConvertible() }))
// Return 2 dataSplits.
Assertions.assertFalse(scanBuilder2.asInstanceOf[SupportsPushDownLimit].pushLimit(2))
@@ -206,7 +199,40 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
// Need to scan all dataSplits.
Assertions.assertEquals(4, dataSplitsWithLimit3.length)
Assertions.assertEquals(1, spark.sql("SELECT * FROM T LIMIT 1").count())
+ }
+ test("Paimon pushDown: limit for table with deletion vector") {
+ Seq(true, false).foreach(
+ deletionVectorsEnabled => {
+ Seq(true, false).foreach(
+ primaryKeyTable => {
+ withTable("T") {
+ sql(s"""
+ |CREATE TABLE T (id INT)
+ |TBLPROPERTIES (
+ | 'deletion-vectors.enabled' = $deletionVectorsEnabled,
+ | '${if (primaryKeyTable) "primary-key" else "bucket-key"}' = 'id',
+ | 'bucket' = '10'
+ |)
+ |""".stripMargin)
+
+ sql("INSERT INTO T SELECT id FROM range (1, 50000)")
+ sql("DELETE FROM T WHERE id % 13 = 0")
+
+ val withoutLimit = getScanBuilder().build().asInstanceOf[PaimonScan].getOriginSplits
+ assert(withoutLimit.length == 10)
+
+ val scanBuilder = getScanBuilder().asInstanceOf[SupportsPushDownLimit]
+ scanBuilder.pushLimit(1)
+ val withLimit = scanBuilder.build().asInstanceOf[PaimonScan].getOriginSplits
+ if (deletionVectorsEnabled || !primaryKeyTable) {
+ assert(withLimit.length == 1)
+ } else {
+ assert(withLimit.length == 10)
+ }
+ }
+ })
+ })
}
test("Paimon pushDown: runtime filter") {
@@ -250,8 +276,7 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
}
private def getScanBuilder(tableName: String = "T"): ScanBuilder = {
- new SparkTable(loadTable(tableName))
- .newScanBuilder(CaseInsensitiveStringMap.empty())
+ SparkTable(loadTable(tableName)).newScanBuilder(CaseInsensitiveStringMap.empty())
}
private def checkFilterExists(sql: String): Boolean = {
@@ -272,5 +297,4 @@ class PaimonPushDownTest extends PaimonSparkTestBase {
case _ => false
}
}
-
}
diff --git a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala
index 501e7bfb4a51..78c02644a7ce 100644
--- a/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala
+++ b/paimon-spark/paimon-spark-ut/src/test/scala/org/apache/paimon/spark/sql/PushDownAggregatesTest.scala
@@ -117,22 +117,58 @@ class PushDownAggregatesTest extends PaimonSparkTestBase with AdaptiveSparkPlanH
}
}
- test("Push down aggregate - primary table") {
- withTable("T") {
- spark.sql("CREATE TABLE T (c1 INT, c2 STRING) TBLPROPERTIES ('primary-key' = 'c1')")
- runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(0) :: Nil, 2)
- spark.sql("INSERT INTO T VALUES(1, 'x'), (2, 'x'), (3, 'x'), (3, 'x')")
- runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(3) :: Nil, 2)
- }
+ test("Push down aggregate - primary key table with deletion vector") {
+ Seq(true, false).foreach(
+ deletionVectorsEnabled => {
+ withTable("T") {
+ spark.sql(s"""
+ |CREATE TABLE T (c1 INT, c2 STRING)
+ |TBLPROPERTIES (
+ |'primary-key' = 'c1',
+ |'deletion-vectors.enabled' = $deletionVectorsEnabled
+ |)
+ |""".stripMargin)
+ runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(0) :: Nil, 0)
+
+ spark.sql("INSERT INTO T VALUES(1, 'x'), (2, 'x'), (3, 'x'), (3, 'x')")
+ runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(3) :: Nil, 0)
+
+ spark.sql("INSERT INTO T VALUES(1, 'x_1')")
+ if (deletionVectorsEnabled) {
+ runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(3) :: Nil, 0)
+ } else {
+ runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(3) :: Nil, 2)
+ }
+ }
+ })
}
- test("Push down aggregate - enable deletion vector") {
- withTable("T") {
- spark.sql(
- "CREATE TABLE T (c1 INT, c2 STRING) TBLPROPERTIES('deletion-vectors.enabled' = 'true')")
- runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(0) :: Nil, 2)
- spark.sql("INSERT INTO T VALUES(1, 'x'), (2, 'x'), (3, 'x'), (3, 'x')")
- runAndCheckAggregate("SELECT COUNT(*) FROM T", Row(4) :: Nil, 2)
- }
+ test("Push down aggregate - table with deletion vector") {
+ Seq(true, false).foreach(
+ deletionVectorsEnabled => {
+ Seq(true, false).foreach(
+ primaryKeyTable => {
+ withTable("T") {
+ sql(s"""
+ |CREATE TABLE T (id INT)
+ |TBLPROPERTIES (
+ | 'deletion-vectors.enabled' = $deletionVectorsEnabled,
+ | '${if (primaryKeyTable) "primary-key" else "bucket-key"}' = 'id',
+ | 'bucket' = '1'
+ |)
+ |""".stripMargin)
+
+ sql("INSERT INTO T SELECT id FROM range (0, 5000)")
+ runAndCheckAggregate("SELECT COUNT(*) FROM T", Seq(Row(5000)), 0)
+
+ sql("DELETE FROM T WHERE id > 100 and id <= 400")
+ if (deletionVectorsEnabled || !primaryKeyTable) {
+ runAndCheckAggregate("SELECT COUNT(*) FROM T", Seq(Row(4700)), 0)
+ } else {
+ runAndCheckAggregate("SELECT COUNT(*) FROM T", Seq(Row(4700)), 2)
+ }
+ }
+ })
+ })
}
}