
Commit 877f82c

xuanyuanking authored and cloud-fan committed
[SPARK-26193][SQL] Implement shuffle write metrics in SQL
## What changes were proposed in this pull request?

1. Implement `SQLShuffleWriteMetricsReporter` on the SQL side as the customized `ShuffleWriteMetricsReporter`.
2. Add shuffle write metrics to `ShuffleExchangeExec`, and use these metrics to create the corresponding `SQLShuffleWriteMetricsReporter` in the shuffle dependency.
3. Rework `ShuffleMapTask` to add a new class named `ShuffleWriteProcessor`, which controls the shuffle write process; the SQL side uses its own shuffle write metrics by customizing a `ShuffleWriteProcessor`.

## How was this patch tested?

Added a UT in SQLMetricsSuite. Manually tested locally; updated the screenshot in the document attached to the JIRA.

Closes #23207 from xuanyuanking/SPARK-26193.

Authored-by: Yuanjian Li <xyliyuanjian@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 55276d3 commit 877f82c
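At a high level, the new pieces connect as follows. The sketch below is condensed from the diffs in this commit and is illustrative only: the object and method names (`ShuffleWriteMetricsWiring`, `wire`) are invented for the example, and all of the referenced classes are Spark-internal (`private[spark]`), so this would only compile inside Spark's own source tree.

```scala
package org.apache.spark.sql.execution.exchange

import org.apache.spark.{Partitioner, ShuffleDependency, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleWriteMetricsReporter}

object ShuffleWriteMetricsWiring {
  def wire(
      sc: SparkContext,
      rddWithPartitionIds: RDD[Product2[Int, InternalRow]],
      partitioner: Partitioner,
      serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
    // 1. Driver side (ShuffleExchangeExec): create named SQL metrics for shuffle writes.
    val writeMetrics: Map[String, SQLMetric] =
      SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sc)

    // 2. A ShuffleWriteProcessor whose reporter wraps the task-level reporter with the
    //    SQL one, so every write updates both TaskMetrics and the plan's SQLMetrics.
    val processor = new ShuffleWriteProcessor {
      override protected def createMetricsReporter(
          context: TaskContext): ShuffleWriteMetricsReporter = {
        new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, writeMetrics)
      }
    }

    // 3. The processor rides along in the ShuffleDependency; on the executor side,
    //    ShuffleMapTask.runTask now simply calls
    //    dep.shuffleWriterProcessor.writeProcess(rdd, dep, partitionId, context, partition).
    new ShuffleDependency[Int, InternalRow, InternalRow](
      rddWithPartitionIds,
      partitioner,
      serializer,
      shuffleWriterProcessor = processor)
  }
}
```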

File tree

9 files changed: +234, -40 lines


core/src/main/scala/org/apache/spark/Dependency.scala

Lines changed: 4 additions & 2 deletions
@@ -22,7 +22,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.ShuffleHandle
+import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}

 /**
  * :: DeveloperApi ::
@@ -65,6 +65,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
  * @param keyOrdering key ordering for RDD's shuffles
  * @param aggregator map/reduce-side aggregator for RDD's shuffle
  * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine)
+ * @param shuffleWriterProcessor the processor to control the write behavior in ShuffleMapTask
  */
 @DeveloperApi
 class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
@@ -73,7 +74,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     val serializer: Serializer = SparkEnv.get.serializer,
     val keyOrdering: Option[Ordering[K]] = None,
     val aggregator: Option[Aggregator[K, V, C]] = None,
-    val mapSideCombine: Boolean = false)
+    val mapSideCombine: Boolean = false,
+    val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor)
   extends Dependency[Product2[K, V]] {

   if (mapSideCombine) {

core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala

Lines changed: 1 addition & 19 deletions
@@ -92,25 +92,7 @@ private[spark] class ShuffleMapTask(
       threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
     } else 0L

-    var writer: ShuffleWriter[Any, Any] = null
-    try {
-      val manager = SparkEnv.get.shuffleManager
-      writer = manager.getWriter[Any, Any](
-        dep.shuffleHandle, partitionId, context, context.taskMetrics().shuffleWriteMetrics)
-      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
-      writer.stop(success = true).get
-    } catch {
-      case e: Exception =>
-        try {
-          if (writer != null) {
-            writer.stop(success = false)
-          }
-        } catch {
-          case e: Exception =>
-            log.debug("Could not stop writer", e)
-        }
-        throw e
-    }
+    dep.shuffleWriterProcessor.writeProcess(rdd, dep, partitionId, context, partition)
   }

   override def preferredLocations: Seq[TaskLocation] = preferredLocs
core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.{Partition, ShuffleDependency, SparkEnv, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.MapStatus
+
+/**
+ * The interface for customizing shuffle write process. The driver create a ShuffleWriteProcessor
+ * and put it into [[ShuffleDependency]], and executors use it in each ShuffleMapTask.
+ */
+private[spark] class ShuffleWriteProcessor extends Serializable with Logging {
+
+  /**
+   * Create a [[ShuffleWriteMetricsReporter]] from the task context. As the reporter is a
+   * per-row operator, here need a careful consideration on performance.
+   */
+  protected def createMetricsReporter(context: TaskContext): ShuffleWriteMetricsReporter = {
+    context.taskMetrics().shuffleWriteMetrics
+  }
+
+  /**
+   * The write process for particular partition, it controls the life circle of [[ShuffleWriter]]
+   * get from [[ShuffleManager]] and triggers rdd compute, finally return the [[MapStatus]] for
+   * this task.
+   */
+  def writeProcess(
+      rdd: RDD[_],
+      dep: ShuffleDependency[_, _, _],
+      partitionId: Int,
+      context: TaskContext,
+      partition: Partition): MapStatus = {
+    var writer: ShuffleWriter[Any, Any] = null
+    try {
+      val manager = SparkEnv.get.shuffleManager
+      writer = manager.getWriter[Any, Any](
+        dep.shuffleHandle,
+        partitionId,
+        context,
+        createMetricsReporter(context))
+      writer.write(
+        rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
+      writer.stop(success = true).get
+    } catch {
+      case e: Exception =>
+        try {
+          if (writer != null) {
+            writer.stop(success = false)
+          }
+        } catch {
+          case e: Exception =>
+            log.debug("Could not stop writer", e)
+        }
+        throw e
+    }
+  }
+}
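The only extension point here is the protected `createMetricsReporter` factory: a subclass decides where per-write updates go, while `writeProcess` keeps owning the `ShuffleWriter` lifecycle (obtain, write, stop, clean up on failure) that used to live in `ShuffleMapTask`. A minimal hypothetical subclass is sketched below; the real SQL-side override is `ShuffleExchangeExec.createShuffleWriteProcessor`, shown later in this commit.

```scala
package org.apache.spark.shuffle

import org.apache.spark.TaskContext

// Hypothetical subclass for illustration only; the involved types are private[spark],
// so a real override has to live inside Spark itself.
private[spark] class PassthroughWriteProcessor extends ShuffleWriteProcessor {
  override protected def createMetricsReporter(
      context: TaskContext): ShuffleWriteMetricsReporter = {
    // Same behavior as the base class; a customized processor would wrap this
    // reporter instead (as the SQL side does with SQLShuffleWriteMetricsReporter).
    context.taskMetrics().shuffleWriteMetrics
  }
}
```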

project/MimaExcludes.scala

Lines changed: 3 additions & 0 deletions
@@ -217,6 +217,9 @@ object MimaExcludes {
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"),
     ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this"),

+    // [SPARK-26139] Implement shuffle write metrics in SQL
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ShuffleDependency.this"),
+
     // Data Source V2 API changes
     (problem: Problem) => problem match {
       case MissingClassProblem(cls) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala

Lines changed: 31 additions & 7 deletions
@@ -23,14 +23,15 @@ import java.util.function.Supplier
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.metric.{SQLMetrics, SQLShuffleMetricsReporter}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleMetricsReporter, SQLShuffleWriteMetricsReporter}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.MutablePair
@@ -46,10 +47,13 @@ case class ShuffleExchangeExec(

   // NOTE: coordinator can be null after serialization/deserialization,
   // e.g. it can be null on the Executor side
-
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
   override lazy val metrics = Map(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size")
-  ) ++ SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
+  ) ++ readMetrics ++ writeMetrics

   override def nodeName: String = {
     val extraInfo = coordinator match {
@@ -90,7 +94,11 @@ case class ShuffleExchangeExec(
   private[exchange] def prepareShuffleDependency()
     : ShuffleDependency[Int, InternalRow, InternalRow] = {
     ShuffleExchangeExec.prepareShuffleDependency(
-      child.execute(), child.output, newPartitioning, serializer)
+      child.execute(),
+      child.output,
+      newPartitioning,
+      serializer,
+      writeMetrics)
   }

   /**
@@ -109,7 +117,7 @@ case class ShuffleExchangeExec(
       assert(newPartitioning.isInstanceOf[HashPartitioning])
       newPartitioning = UnknownPartitioning(indices.length)
     }
-    new ShuffledRowRDD(shuffleDependency, metrics, specifiedPartitionStartIndices)
+    new ShuffledRowRDD(shuffleDependency, readMetrics, specifiedPartitionStartIndices)
   }

   /**
@@ -204,7 +212,9 @@ object ShuffleExchangeExec {
       rdd: RDD[InternalRow],
       outputAttributes: Seq[Attribute],
       newPartitioning: Partitioning,
-      serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
+      serializer: Serializer,
+      writeMetrics: Map[String, SQLMetric])
+    : ShuffleDependency[Int, InternalRow, InternalRow] = {
     val part: Partitioner = newPartitioning match {
       case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
       case HashPartitioning(_, n) =>
@@ -333,8 +343,22 @@ object ShuffleExchangeExec {
       new ShuffleDependency[Int, InternalRow, InternalRow](
         rddWithPartitionIds,
         new PartitionIdPassthrough(part.numPartitions),
-        serializer)
+        serializer,
+        shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics))

     dependency
   }
+
+  /**
+   * Create a customized [[ShuffleWriteProcessor]] for SQL which wrap the default metrics reporter
+   * with [[SQLShuffleWriteMetricsReporter]] as new reporter for [[ShuffleWriteProcessor]].
+   */
+  def createShuffleWriteProcessor(metrics: Map[String, SQLMetric]): ShuffleWriteProcessor = {
+    new ShuffleWriteProcessor {
+      override protected def createMetricsReporter(
+          context: TaskContext): ShuffleWriteMetricsReporter = {
+        new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, metrics)
+      }
+    }
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala

Lines changed: 23 additions & 7 deletions
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
-import org.apache.spark.sql.execution.metric.SQLShuffleMetricsReporter
+import org.apache.spark.sql.execution.metric.{SQLShuffleMetricsReporter, SQLShuffleWriteMetricsReporter}

 /**
  * Take the first `limit` elements and collect them to a single partition.
@@ -38,13 +38,21 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   override def outputPartitioning: Partitioning = SinglePartition
   override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
   private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
-  override lazy val metrics = SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics = readMetrics ++ writeMetrics
   protected override def doExecute(): RDD[InternalRow] = {
     val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit))
     val shuffled = new ShuffledRowRDD(
       ShuffleExchangeExec.prepareShuffleDependency(
-        locallyLimited, child.output, SinglePartition, serializer),
-      metrics)
+        locallyLimited,
+        child.output,
+        SinglePartition,
+        serializer,
+        writeMetrics),
+      readMetrics)
     shuffled.mapPartitionsInternal(_.take(limit))
   }
 }
@@ -154,7 +162,11 @@ case class TakeOrderedAndProjectExec(

   private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)

-  override lazy val metrics = SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics = readMetrics ++ writeMetrics

   protected override def doExecute(): RDD[InternalRow] = {
     val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
@@ -165,8 +177,12 @@ case class TakeOrderedAndProjectExec(
     }
     val shuffled = new ShuffledRowRDD(
       ShuffleExchangeExec.prepareShuffleDependency(
-        localTopK, child.output, SinglePartition, serializer),
-      metrics)
+        localTopK,
+        child.output,
+        SinglePartition,
+        serializer,
+        writeMetrics),
+      readMetrics)
     shuffled.mapPartitions { iter =>
       val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
       if (projectList != child.output) {

sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala

Lines changed: 12 additions & 0 deletions
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.metric
 import java.text.NumberFormat
 import java.util.Locale

+import scala.concurrent.duration._
+
 import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.AccumulableInfo
 import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
@@ -78,6 +80,7 @@ object SQLMetrics {
   private val SUM_METRIC = "sum"
   private val SIZE_METRIC = "size"
   private val TIMING_METRIC = "timing"
+  private val NS_TIMING_METRIC = "nsTiming"
   private val AVERAGE_METRIC = "average"

   private val baseForAvgMetric: Int = 10
@@ -121,6 +124,13 @@ object SQLMetrics {
     acc
   }

+  def createNanoTimingMetric(sc: SparkContext, name: String): SQLMetric = {
+    // Same with createTimingMetric, just normalize the unit of time to millisecond.
+    val acc = new SQLMetric(NS_TIMING_METRIC, -1)
+    acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = false)
+    acc
+  }
+
   /**
    * Create a metric to report the average information (including min, med, max) like
    * avg hash probe. As average metrics are double values, this kind of metrics should be
@@ -163,6 +173,8 @@ object SQLMetrics {
       Utils.bytesToString
     } else if (metricsType == TIMING_METRIC) {
       Utils.msDurationToString
+    } else if (metricsType == NS_TIMING_METRIC) {
+      duration => Utils.msDurationToString(duration.nanos.toMillis)
     } else {
       throw new IllegalStateException("unexpected metrics type: " + metricsType)
     }
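For reference, the new `nsTiming` metric accumulates raw nanoseconds and only converts to milliseconds at display time, via the `scala.concurrent.duration` DSL pulled in by the new import. A small standalone illustration of that conversion (not code from the patch):

```scala
import scala.concurrent.duration._

object NsTimingExample {
  def main(args: Array[String]): Unit = {
    // A shuffle write time accumulated in nanoseconds, as the write-time metric is.
    val writeTimeNs = 1234567890L
    // What the nsTiming formatter does before handing the value to Utils.msDurationToString:
    val writeTimeMs = writeTimeNs.nanos.toMillis
    println(writeTimeMs) // 1234 (milliseconds)
  }
}
```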

sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala

Lines changed: 55 additions & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric

 import org.apache.spark.SparkContext
 import org.apache.spark.executor.TempShuffleReadMetrics
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter

 /**
  * A shuffle metrics reporter for SQL exchange operators.
@@ -95,3 +96,57 @@ private[spark] object SQLShuffleMetricsReporter {
     FETCH_WAIT_TIME -> SQLMetrics.createTimingMetric(sc, "fetch wait time"),
     RECORDS_READ -> SQLMetrics.createMetric(sc, "records read"))
 }
+
+/**
+ * A shuffle write metrics reporter for SQL exchange operators.
+ * @param metricsReporter Other reporter need to be updated in this SQLShuffleWriteMetricsReporter.
+ * @param metrics Shuffle write metrics in current SparkPlan.
+ */
+private[spark] class SQLShuffleWriteMetricsReporter(
+    metricsReporter: ShuffleWriteMetricsReporter,
+    metrics: Map[String, SQLMetric]) extends ShuffleWriteMetricsReporter {
+  private[this] val _bytesWritten =
+    metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_BYTES_WRITTEN)
+  private[this] val _recordsWritten =
+    metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN)
+  private[this] val _writeTime =
+    metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME)
+
+  override private[spark] def incBytesWritten(v: Long): Unit = {
+    metricsReporter.incBytesWritten(v)
+    _bytesWritten.add(v)
+  }
+  override private[spark] def decRecordsWritten(v: Long): Unit = {
+    metricsReporter.decBytesWritten(v)
+    _recordsWritten.set(_recordsWritten.value - v)
+  }
+  override private[spark] def incRecordsWritten(v: Long): Unit = {
+    metricsReporter.incRecordsWritten(v)
+    _recordsWritten.add(v)
+  }
+  override private[spark] def incWriteTime(v: Long): Unit = {
+    metricsReporter.incWriteTime(v)
+    _writeTime.add(v)
+  }
+  override private[spark] def decBytesWritten(v: Long): Unit = {
+    metricsReporter.decBytesWritten(v)
+    _bytesWritten.set(_bytesWritten.value - v)
+  }
+}
+
+private[spark] object SQLShuffleWriteMetricsReporter {
+  val SHUFFLE_BYTES_WRITTEN = "shuffleBytesWritten"
+  val SHUFFLE_RECORDS_WRITTEN = "shuffleRecordsWritten"
+  val SHUFFLE_WRITE_TIME = "shuffleWriteTime"
+
+  /**
+   * Create all shuffle write relative metrics and return the Map.
+   */
+  def createShuffleWriteMetrics(sc: SparkContext): Map[String, SQLMetric] = Map(
+    SHUFFLE_BYTES_WRITTEN ->
+      SQLMetrics.createSizeMetric(sc, "shuffle bytes written"),
+    SHUFFLE_RECORDS_WRITTEN ->
+      SQLMetrics.createMetric(sc, "shuffle records written"),
+    SHUFFLE_WRITE_TIME ->
+      SQLMetrics.createNanoTimingMetric(sc, "shuffle write time"))
+}
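The reporter above is a plain decorator: every update is forwarded to the wrapped task-level reporter (which feeds the stage page and event log) and mirrored into the named `SQLMetric` accumulators (which feed the SQL plan graph). Because `ShuffleWriteMetricsReporter`'s methods are `private[spark]`, here is a self-contained analog of that pattern in plain Scala rather than a call into the real API:

```scala
import scala.collection.mutable

// Stand-in for ShuffleWriteMetricsReporter (whose real methods are private[spark]).
trait WriteSink {
  def incBytesWritten(v: Long): Unit
  def incRecordsWritten(v: Long): Unit
}

// Decorator: forward to the wrapped sink and mirror into named per-plan counters.
class ForwardingSink(underlying: WriteSink, metrics: mutable.Map[String, Long]) extends WriteSink {
  override def incBytesWritten(v: Long): Unit = {
    underlying.incBytesWritten(v)
    metrics("shuffleBytesWritten") += v
  }
  override def incRecordsWritten(v: Long): Unit = {
    underlying.incRecordsWritten(v)
    metrics("shuffleRecordsWritten") += v
  }
}

object ForwardingSinkExample {
  def main(args: Array[String]): Unit = {
    val taskLevel = new WriteSink {
      var bytes = 0L
      var records = 0L
      def incBytesWritten(v: Long): Unit = bytes += v
      def incRecordsWritten(v: Long): Unit = records += v
    }
    val planMetrics = mutable.Map("shuffleBytesWritten" -> 0L, "shuffleRecordsWritten" -> 0L)
    val reporter = new ForwardingSink(taskLevel, planMetrics)
    reporter.incBytesWritten(4096L)
    reporter.incRecordsWritten(1L)
    // Both sinks now hold the same values: the task-level sink has 4096 bytes / 1 record,
    // and planMetrics == Map(shuffleBytesWritten -> 4096, shuffleRecordsWritten -> 1).
    println(planMetrics)
  }
}
```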
