Skip to content

Commit

Permalink
[SPARK-47050][SQL] Collect and publish partition level metrics
Browse files Browse the repository at this point in the history
Capture the partition sub-paths, along with the number of files, bytes, and rows per partition for each task.
  • Loading branch information
Steve Vaughan Jr committed Apr 23, 2024
1 parent b9f2270 commit 2341c22
Show file tree
Hide file tree
Showing 11 changed files with 441 additions and 51 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.write;

import java.io.Serializable;
import java.util.Collections;
import java.util.Map;
import java.util.TreeMap;

/**
 * An aggregator of partition metrics collected during write operations.
 * <p>
 * This is patterned after {@code org.apache.spark.util.AccumulatorV2}
 * </p>
 */
public class PartitionMetricsWriteInfo implements Serializable {

  /** Aggregated metrics keyed by partition path; the {@code TreeMap} keeps paths sorted. */
  private final Map<String, PartitionMetrics> metrics = new TreeMap<>();

  /**
   * Merges another same-type accumulator into this one and updates its state, i.e. this is a
   * merge-in-place. The other accumulator is only read.
   *
   * @param otherAccumulator Another object containing aggregated partition metrics
   */
  public void merge(PartitionMetricsWriteInfo otherAccumulator) {
    otherAccumulator.metrics.forEach((p, m) ->
        metrics.computeIfAbsent(p, key -> new PartitionMetrics(0L, 0L, 0))
            .merge(m));
  }

  /**
   * Update the partition metrics for the specified path by adding to the existing state. This will
   * add the partition if it has not been referenced previously.
   *
   * @param partitionPath The path for the written partition
   * @param bytes The number of additional bytes
   * @param records The number of additional records
   * @param files The number of additional files
   */
  public void update(String partitionPath, long bytes, long records, int files) {
    metrics.computeIfAbsent(partitionPath, key -> new PartitionMetrics(0L, 0L, 0))
        .merge(new PartitionMetrics(bytes, records, files));
  }

  /**
   * Update the partition metrics for the specified path by adding to the existing state from an
   * individual file. This will add the partition if it has not been referenced previously.
   *
   * @param partitionPath The path for the written partition
   * @param bytes The number of additional bytes
   * @param records The number of additional records
   */
  public void updateFile(String partitionPath, long bytes, long records) {
    update(partitionPath, bytes, records, 1);
  }

  /**
   * Convert this instance into an immutable {@code java.util.Map}. This is used for posting to the
   * listener bus.
   * <p>
   * The result is a point-in-time snapshot: both the map and the per-partition metric objects are
   * copied, so later calls to {@link #merge} or {@link #update} on this accumulator do not alter
   * a map that has already been handed out.
   * </p>
   *
   * @return an unmodifiable snapshot mapping partition paths to their metrics, sorted by path
   */
  public Map<String, PartitionMetrics> toMap() {
    Map<String, PartitionMetrics> snapshot = new TreeMap<>();
    // PartitionMetrics is mutable, so each value is copied rather than shared with the snapshot.
    metrics.forEach((path, m) ->
        snapshot.put(path, new PartitionMetrics(m.numBytes(), m.numRecords(), m.numFiles())));
    return Collections.unmodifiableMap(snapshot);
  }

  /**
   * Returns if this accumulator is zero value or not. For a map accumulator this indicates if the
   * map is empty.
   *
   * @return {@code true} if there are no partition metrics
   */
  boolean isZero() {
    return metrics.isEmpty();
  }

  @Override
  public String toString() {
    return "PartitionMetricsWriteInfo{" +
        "metrics=" + metrics +
        '}';
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.write

/**
 * Mutable running totals describing what has been written to a single partition.
 *
 * @param numBytes the number of bytes
 * @param numRecords the number of records (rows)
 * @param numFiles the number of files
 */
case class PartitionMetrics(
    var numBytes: Long = 0,
    var numRecords: Long = 0,
    var numFiles: Int = 0)
  extends Serializable {

  /**
   * Accounts for one additional file: adds its size and row count to the totals and bumps the
   * file count by one.
   *
   * @param bytes the number of bytes in the file
   * @param records the number of records (rows) in the file
   */
  def updateFile(bytes: Long, records: Long): Unit =
    merge(PartitionMetrics(numBytes = bytes, numRecords = records, numFiles = 1))

  /**
   * Adds the totals of `other` to this instance in place; `other` itself is left untouched.
   *
   * @param other Another set of metrics for the same partition
   */
  def merge(other: PartitionMetrics): Unit = {
    numBytes = numBytes + other.numBytes
    numRecords = numRecords + other.numRecords
    numFiles = numFiles + other.numFiles
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.write

import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.scheduler.SparkListenerEvent

/**
 * Listener-bus event carrying per-partition write metrics.
 *
 * @param executorId numeric identifier attached to the event.
 *                   NOTE(review): the only producer visible in this file,
 *                   `SQLPartitionMetrics.postDriverMetricUpdates`, passes the SQL execution id
 *                   (`executionId.toLong`) here, so the name looks like it should be
 *                   `executionId` - confirm before consumers start relying on it.
 * @param metrics map of partition path to the metrics collected for that partition
 */
@DeveloperApi
case class SparkListenerSQLPartitionMetrics(executorId: Long,
metrics: java.util.Map[String, PartitionMetrics])
extends SparkListenerEvent

object SQLPartitionMetrics {

  /**
   * Publishes the aggregated partition write statistics on the listener bus as a
   * [[SparkListenerSQLPartitionMetrics]] event.
   *
   * Nothing is posted when `writeInfo` holds no metrics or when `executionId` is null.
   *
   * @param sc The Spark context whose listener bus receives the event
   * @param executionId The identifier for the SQL execution that resulted in the partition
   *                    writes; it is converted with `toLong` before being posted
   * @param writeInfo The aggregated partition writes for this SQL execution
   */
  def postDriverMetricUpdates(
      sc: SparkContext,
      executionId: String,
      writeInfo: PartitionMetricsWriteInfo): Unit = {
    // Don't bother firing an event if there are no collected metrics. This check is evaluated
    // before the execution id is looked at.
    val hasMetrics = !writeInfo.isZero

    // There are some cases we don't care about the metrics and call `SparkPlan.doExecute`
    // directly without setting an execution id. We should be tolerant to it.
    if (hasMetrics && executionId != null) {
      sc.listenerBus.post(
        SparkListenerSQLPartitionMetrics(executionId.toLong, writeInfo.toMap))
    }
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKey.{ACTUAL_NUM_FILES, EXPECTED_NUM_FILES}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{PartitionMetricsWriteInfo, SQLPartitionMetrics}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
Expand All @@ -43,10 +44,18 @@ case class BasicWriteTaskStats(
partitions: Seq[InternalRow],
numFiles: Int,
numBytes: Long,
numRows: Long)
numRows: Long,
partitionsStats: Map[InternalRow, BasicWritePartitionTaskStats]
= Map[InternalRow, BasicWritePartitionTaskStats]())
extends WriteTaskStats


/**
 * Write statistics for a single partition, as gathered by one task.
 *
 * @param numFiles the number of files counted for the partition
 * @param numBytes the number of bytes counted for the partition
 * @param numRows the number of rows counted for the partition
 */
case class BasicWritePartitionTaskStats(
numFiles: Int,
numBytes: Long,
numRows: Long)
extends PartitionTaskStats

/**
* Simple [[WriteTaskStatsTracker]] implementation that produces [[BasicWriteTaskStats]].
*/
Expand All @@ -56,12 +65,19 @@ class BasicWriteTaskStatsTracker(
extends WriteTaskStatsTracker with Logging {

// Partition values this task has been told about.
private[this] val partitions: mutable.ArrayBuffer[InternalRow] = mutable.ArrayBuffer.empty
// Map each partition to counts of the number of files, bytes, and rows written
// partition -> (files, bytes, rows)
// Partitions that have not been seen yet read as (0, 0L, 0L) through the default value.
private[this] val partitionsStats: mutable.Map[InternalRow, (Int, Long, Long)] =
mutable.Map.empty.withDefaultValue((0, 0L, 0L))
// Task-wide totals: files whose size could be determined, files handed to newFile, bytes, rows.
private[this] var numFiles: Int = 0
private[this] var numSubmittedFiles: Int = 0
private[this] var numBytes: Long = 0L
private[this] var numRows: Long = 0L

// Files registered through newFile that have not been closed yet.
private[this] val submittedFiles = mutable.HashSet[String]()
// For open files that were registered with partition values: file path -> partition.
private[this] val submittedPartitionFiles = mutable.Map[String, InternalRow]()

// Rows written so far per open file path; a path with no rows yet reads as 0.
private[this] val numFileRows: mutable.Map[String, Long] = mutable.Map.empty.withDefaultValue(0)

/**
* Get the size of the file expected to have been written by a worker.
Expand Down Expand Up @@ -138,25 +154,45 @@ class BasicWriteTaskStatsTracker(
partitions.append(partitionValues)
}

override def newFile(filePath: String): Unit = {
/**
 * Registers a file that is about to be written.
 *
 * @param filePath path of the new file
 * @param partitionValues the partition the file belongs to, when the caller supplies one
 */
override def newFile(filePath: String, partitionValues: Option[InternalRow] = None): Unit = {
  submittedFiles.add(filePath)
  numSubmittedFiles = numSubmittedFiles + 1

  // Remember which partition the file belongs to, so its size and row count can be attributed
  // to that partition once the file's statistics are collected.
  partitionValues.foreach { partition =>
    submittedPartitionFiles(filePath) = partition
  }
}

/**
 * Finalises the bookkeeping for a file that has been closed.
 *
 * @param filePath path of the file that was closed
 */
override def closeFile(filePath: String): Unit = {
// Collect the statistics first: updateFileStats reads the file's partition and row count from
// the maps that are cleared below.
updateFileStats(filePath)
// The file is no longer pending, so drop every per-file entry kept for it.
submittedFiles.remove(filePath)
submittedPartitionFiles.remove(filePath)
numFileRows.remove(filePath)
}

/**
 * Adds one file to the task totals and, when the file was registered with a partition, to that
 * partition's (files, bytes, rows) counters. Nothing is recorded when `getFileSize` yields no
 * length for the path.
 *
 * @param filePath path of the file whose statistics should be recorded
 */
private def updateFileStats(filePath: String): Unit = {
  for (len <- getFileSize(filePath)) {
    numBytes += len
    numFiles += 1

    // Only files that were submitted together with partition values have an entry here.
    for (partition <- submittedPartitionFiles.get(filePath)) {
      val (files, bytes, rows) = partitionsStats(partition)
      partitionsStats(partition) = (files + 1, bytes + len, rows + numFileRows(filePath))
    }
  }
}

/**
 * Records that one more row has been written to `filePath`.
 *
 * @param filePath path of the file receiving the row
 * @param row the row that was written (not inspected here)
 */
override def newRow(filePath: String, row: InternalRow): Unit = {
  numRows = numRows + 1

  // Keep a per-file row count as well; it is folded into the file's partition, if there is one,
  // when the file's statistics are collected.
  numFileRows.update(filePath, numFileRows(filePath) + 1)
}

override def getFinalStats(taskCommitTime: Long): WriteTaskStats = {
Expand All @@ -172,7 +208,15 @@ class BasicWriteTaskStatsTracker(
log"writing empty files, or files being not immediately visible in the filesystem.")
}
taskCommitTimeMetric.foreach(_ += taskCommitTime)
BasicWriteTaskStats(partitions.toSeq, numFiles, numBytes, numRows)

val publish: ((InternalRow, (Int, Long, Long))) =>
(InternalRow, BasicWritePartitionTaskStats) = {
case (key, value) =>
val newValue = BasicWritePartitionTaskStats(value._1, value._2, value._3)
key -> newValue
}
BasicWriteTaskStats(partitions.toSeq, numFiles, numBytes, numRows,
partitionsStats.map(publish).toMap)
}
}

Expand All @@ -189,6 +233,8 @@ class BasicWriteJobStatsTracker(
taskCommitTimeMetric: SQLMetric)
extends WriteJobStatsTracker {

val partitionMetrics: PartitionMetricsWriteInfo = new PartitionMetricsWriteInfo()

def this(
serializableHadoopConf: SerializableConfiguration,
metrics: Map[String, SQLMetric]) = {
Expand All @@ -199,7 +245,8 @@ class BasicWriteJobStatsTracker(
new BasicWriteTaskStatsTracker(serializableHadoopConf.value, Some(taskCommitTimeMetric))
}

override def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long): Unit = {
override def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long,
partitionsMap: Map[InternalRow, String]): Unit = {
val sparkContext = SparkContext.getActive.get
val partitionsSet: mutable.Set[InternalRow] = mutable.HashSet.empty
var numFiles: Long = 0L
Expand All @@ -213,6 +260,14 @@ class BasicWriteJobStatsTracker(
numFiles += summary.numFiles
totalNumBytes += summary.numBytes
totalNumOutput += summary.numRows

summary.partitionsStats.foreach(s => {
// Check if we know the mapping of the internal row to a partition path
if (partitionsMap.contains(s._1)) {
val path = partitionsMap(s._1)
partitionMetrics.update(path, s._2.numBytes, s._2.numRows, s._2.numFiles)
}
})
}

driverSideMetrics(JOB_COMMIT_TIME).add(jobCommitTime)
Expand All @@ -223,6 +278,9 @@ class BasicWriteJobStatsTracker(

val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, driverSideMetrics.values.toList)

SQLPartitionMetrics.postDriverMetricUpdates(sparkContext, executionId,
partitionMetrics)
}
}

Expand All @@ -247,4 +305,7 @@ object BasicWriteJobStatsTracker {
JOB_COMMIT_TIME -> SQLMetrics.createTimingMetric(sparkContext, "job commit time")
)
}

/**
 * Builds an empty mutable map from partition path to [[PartitionTaskStats]] in which a missing
 * key reads as all-zero statistics.
 *
 * NOTE(review): this is a `def`, so every call returns a brand-new empty map - confirm that
 * callers keep a reference to the result rather than calling this repeatedly.
 */
def partitionMetrics: mutable.Map[String, PartitionTaskStats] = {
  val zeroStats: PartitionTaskStats = BasicWritePartitionTaskStats(0, 0L, 0L)
  mutable.Map.empty[String, PartitionTaskStats].withDefaultValue(zeroStats)
}
}
Loading

0 comments on commit 2341c22

Please sign in to comment.