
Commit 1474ed0

Marcelo Vanzin authored and dongjoon-hyun committed
[SPARK-29562][SQL] Speed up and slim down metric aggregation in SQL listener
First, a bit of background on the code being changed. The current code tracks
metric updates for each task, recording which metrics the task is monitoring
and the last update value.

Once a SQL execution finishes, the metrics for all the stages are aggregated
by building a list with all (metric ID, value) pairs collected for all tasks
in the stages related to the execution, then grouping by metric ID, and then
calculating the values shown in the UI.

That is full of inefficiencies:

- in normal operation, all tasks will be tracking and updating the same
  metrics. So recording the metric IDs per task is wasteful.
- tracking by task means we might be double-counting values if you have
  speculative tasks (as a comment in the code mentions).
- creating a list of (metric ID, value) is extremely inefficient, because now
  you have a huge map in memory storing boxed versions of the metric IDs and
  values.
- same thing for the aggregation part, where now a Seq is built with the values
  for each metric ID.

The end result is that for large queries, this code can become both really
slow, thus affecting the processing of events, and memory hungry.

The updated code changes the approach to the following:

- stages track metrics by their ID; this means the stage tracking code
  naturally groups values, making aggregation later simpler.
- each metric ID being tracked uses a long array matching the number of
  partitions of the stage; this means that it's cheap to update the value of
  the metric once a task ends.
- when aggregating, custom code just concatenates the arrays corresponding to
  the matching metric IDs; this is cheaper than the previous, boxing-heavy
  approach.

The end result is that the listener uses about half as much memory as before
for tracking metrics, since it doesn't need to track metric IDs per task.

I captured heap dumps with the old and the new code during metric aggregation
in the listener, for an execution with 3 stages, 100k tasks per stage, 50
metrics updated per task. The dumps contained just reachable memory - so data
kept by the listener plus the variables in the aggregateMetrics() method.

With the old code, the thread doing aggregation references >1G of memory - and
that does not include temporary data created by the "groupBy" transformation
(for which the intermediate state is not referenced in the aggregation method).
The same thread with the new code references ~250M of memory. The old code uses
about 250M to track all the metric values for that execution, while the new
code uses about 130M. (Note the per-thread numbers include the amount used to
track the metrics - so, e.g., in the old case, aggregation was referencing
about 750M of temporary data.)

I'm also including a small benchmark (based on the Benchmark class) so that we
can measure how much changes to this code affect performance. The benchmark
contains some extra code to measure things the normal Benchmark class does not,
given that the code under test does not really map that well to the
expectations of that class.

Running with the old code (I removed results that don't make much sense for
this benchmark):

```
[info] Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Linux 4.15.0-66-generic
[info] Intel(R) Core(TM) i7-6820HQ CPU 2.70GHz
[info] metrics aggregation (50 metrics, 100k tasks per stage):  Best Time(ms)   Avg Time(ms)
[info] --------------------------------------------------------------------------------------
[info] 1 stage(s)                                                        2113           2118
[info] 2 stage(s)                                                        4172           4392
[info] 3 stage(s)                                                        7755           8460
[info]
[info] Stage Count    Stage Proc. Time    Aggreg. Time
[info]      1              614                1187
[info]      2              620                2480
[info]      3              718                5069
```

With the new code:

```
[info] Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Linux 4.15.0-66-generic
[info] Intel(R) Core(TM) i7-6820HQ CPU 2.70GHz
[info] metrics aggregation (50 metrics, 100k tasks per stage):  Best Time(ms)   Avg Time(ms)
[info] --------------------------------------------------------------------------------------
[info] 1 stage(s)                                                         727            886
[info] 2 stage(s)                                                        1722           1983
[info] 3 stage(s)                                                        2752           3013
[info]
[info] Stage Count    Stage Proc. Time    Aggreg. Time
[info]      1              408                177
[info]      2              389                423
[info]      3              372                660
```

So the new code is faster than the old when processing task events, and about
an order of magnitude faster when aggregating metrics.

Note this still leaves room for improvement; for example, using the above
measurements, 600ms is still a huge amount of time to spend in an event
handler. But I'll leave further enhancements for a separate change.

Tested with benchmarking code + existing unit tests.

Closes #26218 from vanzin/SPARK-29562.

Authored-by: Marcelo Vanzin <vanzin@cloudera.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
1 parent 7417c3e commit 1474ed0
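
The two aggregation strategies described in the commit message can be sketched roughly as follows. This is a simplified illustration rather than the listener code: `MetricAggregationSketch`, `aggregateByPairs` and `aggregateByArrays` are hypothetical names, and each stage is modeled as a plain map from metric ID to a `Long` array.

```scala
import scala.collection.mutable

object MetricAggregationSketch {
  // Old style (simplified): everything is flattened into boxed
  // (metric ID, value) pairs, which are then grouped by metric ID.
  def aggregateByPairs(pairs: Seq[(Long, Long)]): Map[Long, Seq[Long]] =
    pairs.groupBy(_._1).map { case (id, vs) => id -> vs.map(_._2) }

  // New style (simplified): each stage already holds one primitive array per
  // metric ID, so aggregation only concatenates arrays that share an ID.
  def aggregateByArrays(stages: Seq[Map[Long, Array[Long]]]): Map[Long, Array[Long]] = {
    val all = new mutable.HashMap[Long, Array[Long]]()
    for (stage <- stages; (id, values) <- stage) {
      val merged = all.get(id) match {
        case Some(prev) => prev ++ values // concatenate primitive arrays
        case None => values
      }
      all.update(id, merged)
    }
    all.toMap
  }
}
```

With two stages that both report metric 1, `aggregateByArrays` returns a single array for that ID containing the first stage's values followed by the second stage's values, without creating a tuple per value along the way.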

7 files changed: +397 -72 lines changed

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+OpenJDK 64-Bit Server VM 11.0.4+11 on Linux 4.15.0-66-generic
+Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
+metrics aggregation (50 metrics, 100000 tasks per stage):  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+1 stage(s)                                                           672            841         179          0.0   671888474.0       1.0X
+2 stage(s)                                                          1700           1842         201          0.0  1699591662.0       0.4X
+3 stage(s)                                                          2601           2776         247          0.0  2601465786.0       0.3X
+
+Stage Count    Stage Proc. Time    Aggreg. Time
+     1              436                164
+     2              537                354
+     3              480                602
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Linux 4.15.0-66-generic
+Intel(R) Core(TM) i7-6820HQ CPU @ 2.70GHz
+metrics aggregation (50 metrics, 100000 tasks per stage):  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+1 stage(s)                                                           740            883         147          0.0   740089816.0       1.0X
+2 stage(s)                                                          1661           1943         399          0.0  1660649192.0       0.4X
+3 stage(s)                                                          2711           2967         362          0.0  2711110178.0       0.3X
+
+Stage Count    Stage Proc. Time    Aggreg. Time
+     1              405                179
+     2              375                414
+     3              364                644
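
To put the "Aggreg. Time" column in context, the aggregation speedup can be worked out with a few lines of arithmetic. The sketch below assumes the old-code times quoted in the commit message (1187, 2480 and 5069 ms) and the new-code times from the HotSpot 1.8 results file above (179, 414 and 644 ms); `AggregationSpeedup` is a hypothetical name used only for this illustration.

```scala
object AggregationSpeedup {
  // Aggregation times in milliseconds for 1, 2 and 3 stages.
  val oldAggregMs: Seq[Double] = Seq(1187.0, 2480.0, 5069.0) // old code, from the commit message
  val newAggregMs: Seq[Double] = Seq(179.0, 414.0, 644.0)    // new code, HotSpot 1.8 results file

  // Ratio of old to new aggregation time, one entry per stage count.
  def speedups: Seq[Double] =
    oldAggregMs.zip(newAggregMs).map { case (o, n) => o / n }

  def main(args: Array[String]): Unit =
    speedups.zipWithIndex.foreach { case (s, i) =>
      println(f"${i + 1} stage(s): $s%.1fx faster aggregation")
    }
}
```

On these numbers the ratio comes out between roughly 6x and 8x for each stage count.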

sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala

Lines changed: 8 additions & 6 deletions
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.metric

 import java.text.NumberFormat
-import java.util.Locale
+import java.util.{Arrays, Locale}

 import scala.concurrent.duration._

@@ -150,7 +150,7 @@ object SQLMetrics {
    * A function that defines how we aggregate the final accumulator results among all tasks,
    * and represent it in string for a SQL physical operator.
    */
-  def stringValue(metricsType: String, values: Seq[Long]): String = {
+  def stringValue(metricsType: String, values: Array[Long]): String = {
     if (metricsType == SUM_METRIC) {
       val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
       numberFormat.format(values.sum)
@@ -162,8 +162,9 @@ object SQLMetrics {
       val metric = if (validValues.isEmpty) {
         Seq.fill(3)(0L)
       } else {
-        val sorted = validValues.sorted
-        Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+        Arrays.sort(validValues)
+        Seq(validValues(0), validValues(validValues.length / 2),
+          validValues(validValues.length - 1))
       }
       metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric))
     }
@@ -184,8 +185,9 @@ object SQLMetrics {
       val metric = if (validValues.isEmpty) {
         Seq.fill(4)(0L)
       } else {
-        val sorted = validValues.sorted
-        Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+        Arrays.sort(validValues)
+        Seq(validValues.sum, validValues(0), validValues(validValues.length / 2),
+          validValues(validValues.length - 1))
       }
       metric.map(strFormat)
     }
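
The switch from `.sorted` to `Arrays.sort` in the hunks above replaces a sort that allocates a new collection with one that reorders the primitive array in place. A minimal sketch of that pattern is shown below, assuming a non-empty input; `InPlaceStats` and `minMedMaxInPlace` are hypothetical names, not part of `SQLMetrics`.

```scala
import java.util.Arrays

object InPlaceStats {
  // Sorts `values` in place (no copy is made) and returns (min, median, max).
  // Side effect: the caller's array is reordered.
  def minMedMaxInPlace(values: Array[Long]): (Long, Long, Long) = {
    require(values.nonEmpty, "values must not be empty")
    Arrays.sort(values)
    (values(0), values(values.length / 2), values(values.length - 1))
  }
}
```

Because the array is mutated, this only suits callers that own the array or no longer care about its original order, which is worth keeping in mind when reusing the pattern elsewhere.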

sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala

Lines changed: 131 additions & 59 deletions
@@ -16,10 +16,11 @@
  */
 package org.apache.spark.sql.execution.ui

-import java.util.{Date, NoSuchElementException}
+import java.util.{Arrays, Date, NoSuchElementException}
 import java.util.concurrent.ConcurrentHashMap

 import scala.collection.JavaConverters._
+import scala.collection.mutable

 import org.apache.spark.{JobExecutionStatus, SparkConf}
 import org.apache.spark.internal.Logging
@@ -29,6 +30,7 @@ import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.metric._
 import org.apache.spark.sql.internal.StaticSQLConf._
 import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
+import org.apache.spark.util.collection.OpenHashMap

 class SQLAppStatusListener(
     conf: SparkConf,
@@ -103,8 +105,10 @@ class SQLAppStatusListener(
     // Record the accumulator IDs for the stages of this job, so that the code that keeps
     // track of the metrics knows which accumulators to look at.
     val accumIds = exec.metrics.map(_.accumulatorId).toSet
-    event.stageIds.foreach { id =>
-      stageMetrics.put(id, new LiveStageMetrics(id, 0, accumIds, new ConcurrentHashMap()))
+    if (accumIds.nonEmpty) {
+      event.stageInfos.foreach { stage =>
+        stageMetrics.put(stage.stageId, new LiveStageMetrics(0, stage.numTasks, accumIds))
+      }
     }

     exec.jobs = exec.jobs + (jobId -> JobExecutionStatus.RUNNING)
@@ -118,9 +122,11 @@ class SQLAppStatusListener(
     }

     // Reset the metrics tracking object for the new attempt.
-    Option(stageMetrics.get(event.stageInfo.stageId)).foreach { metrics =>
-      metrics.taskMetrics.clear()
-      metrics.attemptId = event.stageInfo.attemptNumber
+    Option(stageMetrics.get(event.stageInfo.stageId)).foreach { stage =>
+      if (stage.attemptId != event.stageInfo.attemptNumber) {
+        stageMetrics.put(event.stageInfo.stageId,
+          new LiveStageMetrics(event.stageInfo.attemptNumber, stage.numTasks, stage.accumulatorIds))
+      }
     }
   }

@@ -140,7 +146,16 @@ class SQLAppStatusListener(

   override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = {
     event.accumUpdates.foreach { case (taskId, stageId, attemptId, accumUpdates) =>
-      updateStageMetrics(stageId, attemptId, taskId, accumUpdates, false)
+      updateStageMetrics(stageId, attemptId, taskId, SQLAppStatusListener.UNKNOWN_INDEX,
+        accumUpdates, false)
+    }
+  }
+
+  override def onTaskStart(event: SparkListenerTaskStart): Unit = {
+    Option(stageMetrics.get(event.stageId)).foreach { stage =>
+      if (stage.attemptId == event.stageAttemptId) {
+        stage.registerTask(event.taskInfo.taskId, event.taskInfo.index)
+      }
     }
   }

@@ -165,7 +180,7 @@ class SQLAppStatusListener(
     } else {
       info.accumulables
     }
-    updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, accums,
+    updateStageMetrics(event.stageId, event.stageAttemptId, info.taskId, info.index, accums,
       info.successful)
   }

@@ -181,17 +196,40 @@ class SQLAppStatusListener(

   private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = {
     val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap
-    val metrics = exec.stages.toSeq
+
+    val taskMetrics = exec.stages.toSeq
       .flatMap { stageId => Option(stageMetrics.get(stageId)) }
-      .flatMap(_.taskMetrics.values().asScala)
-      .flatMap { metrics => metrics.ids.zip(metrics.values) }
-
-    val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq)
-      .filter { case (id, _) => metricTypes.contains(id) }
-      .groupBy(_._1)
-      .map { case (id, values) =>
-        id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2))
+      .flatMap(_.metricValues())
+
+    val allMetrics = new mutable.HashMap[Long, Array[Long]]()
+
+    taskMetrics.foreach { case (id, values) =>
+      val prev = allMetrics.getOrElse(id, null)
+      val updated = if (prev != null) {
+        prev ++ values
+      } else {
+        values
       }
+      allMetrics(id) = updated
+    }
+
+    exec.driverAccumUpdates.foreach { case (id, value) =>
+      if (metricTypes.contains(id)) {
+        val prev = allMetrics.getOrElse(id, null)
+        val updated = if (prev != null) {
+          val _copy = Arrays.copyOf(prev, prev.length + 1)
+          _copy(prev.length) = value
+          _copy
+        } else {
+          Array(value)
+        }
+        allMetrics(id) = updated
+      }
+    }
+
+    val aggregatedMetrics = allMetrics.map { case (id, values) =>
+      id -> SQLMetrics.stringValue(metricTypes(id), values)
+    }.toMap

     // Check the execution again for whether the aggregated metrics data has been calculated.
     // This can happen if the UI is requesting this data, and the onExecutionEnd handler is
@@ -208,43 +246,13 @@ class SQLAppStatusListener(
       stageId: Int,
       attemptId: Int,
       taskId: Long,
+      taskIdx: Int,
       accumUpdates: Seq[AccumulableInfo],
       succeeded: Boolean): Unit = {
     Option(stageMetrics.get(stageId)).foreach { metrics =>
-      if (metrics.attemptId != attemptId || metrics.accumulatorIds.isEmpty) {
-        return
-      }
-
-      val oldTaskMetrics = metrics.taskMetrics.get(taskId)
-      if (oldTaskMetrics != null && oldTaskMetrics.succeeded) {
-        return
+      if (metrics.attemptId == attemptId) {
+        metrics.updateTaskMetrics(taskId, taskIdx, succeeded, accumUpdates)
       }
-
-      val updates = accumUpdates
-        .filter { acc => acc.update.isDefined && metrics.accumulatorIds.contains(acc.id) }
-        .sortBy(_.id)
-
-      if (updates.isEmpty) {
-        return
-      }
-
-      val ids = new Array[Long](updates.size)
-      val values = new Array[Long](updates.size)
-      updates.zipWithIndex.foreach { case (acc, idx) =>
-        ids(idx) = acc.id
-        // In a live application, accumulators have Long values, but when reading from event
-        // logs, they have String values. For now, assume all accumulators are Long and covert
-        // accordingly.
-        values(idx) = acc.update.get match {
-          case s: String => s.toLong
-          case l: Long => l
-          case o => throw new IllegalArgumentException(s"Unexpected: $o")
-        }
-      }
-
-      // TODO: storing metrics by task ID can cause metrics for the same task index to be
-      // counted multiple times, for example due to speculation or re-attempts.
-      metrics.taskMetrics.put(taskId, new LiveTaskMetrics(ids, values, succeeded))
     }
   }

@@ -425,12 +433,76 @@ private class LiveExecutionData(val executionId: Long) extends LiveEntity {
 }

 private class LiveStageMetrics(
-    val stageId: Int,
-    var attemptId: Int,
-    val accumulatorIds: Set[Long],
-    val taskMetrics: ConcurrentHashMap[Long, LiveTaskMetrics])
-
-private class LiveTaskMetrics(
-    val ids: Array[Long],
-    val values: Array[Long],
-    val succeeded: Boolean)
+    val attemptId: Int,
+    val numTasks: Int,
+    val accumulatorIds: Set[Long]) {
+
+  /**
+   * Mapping of task IDs to their respective index. Note this may contain more elements than the
+   * stage's number of tasks, if speculative execution is on.
+   */
+  private val taskIndices = new OpenHashMap[Long, Int]()
+
+  /** Bit set tracking which indices have been successfully computed. */
+  private val completedIndices = new mutable.BitSet()
+
+  /**
+   * Task metrics values for the stage. Maps the metric ID to the metric values for each
+   * index. For each metric ID, there will be the same number of values as the number
+   * of indices. This relies on `SQLMetrics.stringValue` treating 0 as a neutral value,
+   * independent of the actual metric type.
+   */
+  private val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]()
+
+  def registerTask(taskId: Long, taskIdx: Int): Unit = {
+    taskIndices.update(taskId, taskIdx)
+  }
+
+  def updateTaskMetrics(
+      taskId: Long,
+      eventIdx: Int,
+      finished: Boolean,
+      accumUpdates: Seq[AccumulableInfo]): Unit = {
+    val taskIdx = if (eventIdx == SQLAppStatusListener.UNKNOWN_INDEX) {
+      if (!taskIndices.contains(taskId)) {
+        // We probably missed the start event for the task, just ignore it.
+        return
+      }
+      taskIndices(taskId)
+    } else {
+      // Here we can recover from a missing task start event. Just register the task again.
+      registerTask(taskId, eventIdx)
+      eventIdx
+    }
+
+    if (completedIndices.contains(taskIdx)) {
+      return
+    }
+
+    accumUpdates
+      .filter { acc => acc.update.isDefined && accumulatorIds.contains(acc.id) }
+      .foreach { acc =>
+        // In a live application, accumulators have Long values, but when reading from event
+        // logs, they have String values. For now, assume all accumulators are Long and convert
+        // accordingly.
+        val value = acc.update.get match {
+          case s: String => s.toLong
+          case l: Long => l
+          case o => throw new IllegalArgumentException(s"Unexpected: $o")
+        }
+
+        val metricValues = taskMetrics.computeIfAbsent(acc.id, _ => new Array(numTasks))
+        metricValues(taskIdx) = value
+      }
+
+    if (finished) {
+      completedIndices += taskIdx
+    }
+  }
+
+  def metricValues(): Seq[(Long, Array[Long])] = taskMetrics.asScala.toSeq
+}
+
+private object SQLAppStatusListener {
+  val UNKNOWN_INDEX = -1
+}
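
The way the new `LiveStageMetrics` avoids double-counting speculative tasks can be illustrated with a stripped-down model. The sketch below is an assumption-laden stand-in, not the class from the diff: `StageMetricsSketch` is a hypothetical name, plain Scala collections replace Spark's `OpenHashMap` and `ConcurrentHashMap`, and accumulator updates are reduced to `(metricId, value)` pairs.

```scala
import scala.collection.mutable

class StageMetricsSketch(numTasks: Int) {
  // Task ID -> partition index; speculative copies map to the same index.
  private val taskIndices = new mutable.HashMap[Long, Int]()
  // Partition indices for which a finished task has already been recorded.
  private val completed = new mutable.BitSet()
  // Metric ID -> one slot per partition index.
  private val values = new mutable.HashMap[Long, Array[Long]]()

  def registerTask(taskId: Long, taskIdx: Int): Unit =
    taskIndices.update(taskId, taskIdx)

  def update(taskId: Long, finished: Boolean, updates: Seq[(Long, Long)]): Unit = {
    // Updates for unregistered tasks are dropped.
    taskIndices.get(taskId).foreach { idx =>
      // Once an index has completed, later updates for it are ignored.
      if (!completed.contains(idx)) {
        updates.foreach { case (metricId, value) =>
          val slots = values.getOrElseUpdate(metricId, new Array[Long](numTasks))
          slots(idx) = value
        }
        if (finished) {
          completed += idx
        }
      }
    }
  }

  def metricValues(metricId: Long): Option[Seq[Long]] =
    values.get(metricId).map(_.toSeq)
}
```

If tasks 100 and 101 are both registered for index 0 (the second being a speculative copy) and task 100 finishes first, a later update from task 101 leaves the slot for index 0 untouched, so each partition contributes exactly one value per metric.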

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala

Lines changed: 2 additions & 1 deletion
@@ -232,7 +232,8 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
       val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId)
       assert(expectedNodeName === actualNodeName)
       for ((metricName, metricPredicate) <- expectedMetricsPredicatesMap) {
-        assert(metricPredicate(actualMetricsMap(metricName)))
+        assert(metricPredicate(actualMetricsMap(metricName)),
+          s"$nodeId / '$metricName' (= ${actualMetricsMap(metricName)}) did not match predicate.")
       }
     }
   }
