
Commit 0b72660

Initial WIP example of supporting globally named accumulators.

1 parent 9d5ecf8
7 files changed: +75 −7 lines
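In short: Accumulator gains a display name and a display flag, SparkContext gains an accumulator(initialValue, name) overload, the DAGScheduler records accumulator updates into StageInfo and TaskInfo, and the stage UI renders them per stage and per task. A minimal usage sketch against the API added here (the app name, master, input path, and accumulator name are all illustrative):

import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("NamedAccumulators").setMaster("local[2]"))

// New overload from this commit: the second argument is the name shown in the UI.
val emptyLines = sc.accumulator(0, "emptyLines")

sc.textFile("/tmp/input.txt") // illustrative input path
  .foreach(line => if (line.isEmpty) emptyLines += 1)

// Only the driver can read the value; the same name appears on the stage page.
println(s"emptyLines = ${emptyLines.value}")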

core/src/main/scala/org/apache/spark/Accumulators.scala

Lines changed: 13 additions & 2 deletions
@@ -51,6 +51,13 @@ class Accumulable[R, T] (

   Accumulators.register(this, true)

+  /** A name for this accumulator / accumulable for display in Spark's UI.
+    * Note that names must be unique within a SparkContext. */
+  def name: String = s"accumulator_$id"
+
+  /** Whether to display this accumulator in the web UI. */
+  def display: Boolean = true
+
   /**
    * Add more data to this accumulator / accumulable
    * @param term the data to add
@@ -219,8 +226,12 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
  * @param param helper object defining how to add elements of type `T`
  * @tparam T result type
  */
-class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T])
-  extends Accumulable[T,T](initialValue, param)
+class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], _name: String, _display: Boolean)
+  extends Accumulable[T,T](initialValue, param) {
+  override def name = if (_name.eq(null)) s"accumulator_$id" else _name
+  override def display = _display
+  def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, null, true)
+}

 /**
  * A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add
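The null check uses reference equality (_name.eq(null)), so the default s"accumulator_$id" name is synthesized only when no name was passed, and the auxiliary constructor keeps the old two-argument signature source-compatible. A sketch of the resulting behavior (IntAccumulatorParam is the implicit from the SparkContext companion object; the values shown are illustrative):

import org.apache.spark.SparkContext.IntAccumulatorParam

val unnamed = new Accumulator(0, IntAccumulatorParam)               // name == "accumulator_<id>"
val named   = new Accumulator(0, IntAccumulatorParam, "hits", true) // name == "hits"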

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 10 additions & 0 deletions
@@ -757,6 +757,16 @@ class SparkContext(config: SparkConf) extends Logging {
   def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
     new Accumulator(initialValue, param)

+  /**
+   * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add"
+   * values to using the `+=` method. Only the driver can access the accumulator's `value`.
+   *
+   * This version adds a custom name to the accumulator for display in the Spark UI.
+   */
+  def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = {
+    new Accumulator(initialValue, param, name, true)
+  }
+
   /**
    * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values
    * with `+=`. Only the driver can access the accumulable's `value`.
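Callers opt in per accumulator; the existing overload keeps its behavior. A quick sketch of the difference, assuming an active SparkContext sc (the name is illustrative):

val anonymous = sc.accumulator(0)               // displayed as accumulator_<id>
val named     = sc.accumulator(0, "badRecords") // displayed as "badRecords"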

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 17 additions & 2 deletions
@@ -791,9 +791,10 @@ class DAGScheduler(
     val task = event.task
     val stageId = task.stageId
     val taskType = Utils.getFormattedClassName(task)
-    listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo,
-      event.taskMetrics))
+
     if (!stageIdToStage.contains(task.stageId)) {
+      listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo,
+        event.taskMetrics))
       // Skip all the actions if the stage has been cancelled.
       return
     }
@@ -809,12 +810,24 @@ class DAGScheduler(
       listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
       runningStages -= stage
     }
+
     event.reason match {
       case Success =>
         logInfo("Completed " + task)
         if (event.accumUpdates != null) {
           // TODO: fail the stage if the accumulator update fails...
           Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
+          event.accumUpdates.foreach { case (id, partialValue) =>
+            val acc = Accumulators.originals(id)
+            val name = acc.name
+            // To avoid UI cruft, ignore cases where value wasn't updated
+            if (partialValue != acc.zero) {
+              val stringPartialValue = s"${partialValue}"
+              val stringValue = s"${acc.value}"
+              stageToInfos(stage).accumulatorValues(name) = stringValue
+              event.taskInfo.accumValues += ((name, stringPartialValue))
+            }
+          }
         }
         pendingTasks(stage) -= task
         task match {
@@ -945,6 +958,8 @@ class DAGScheduler(
       // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
      // will abort the job.
     }
+    listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo,
+      event.taskMetrics))
     submitWaitingStages()
   }
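Two things happen here: the SparkListenerTaskEnd post is moved so it fires only after the accumulator values have been folded into TaskInfo (listeners now observe a populated accumValues), and updates still sitting at the accumulator's zero are skipped to avoid UI cruft. A standalone sketch of the recording step (the helper name recordAccumUpdates is ours, not Spark's; Accumulators.originals and the StageInfo/TaskInfo fields are as in this commit):

// Hypothetical helper distilling the recording logic above.
def recordAccumUpdates(
    accumUpdates: Map[Long, Any],
    stageInfo: StageInfo,
    taskInfo: TaskInfo): Unit = {
  accumUpdates.foreach { case (id, partialValue) =>
    val acc = Accumulators.originals(id)
    if (partialValue != acc.zero) {                                // skip no-op updates
      stageInfo.accumulatorValues(acc.name) = acc.value.toString   // running total, per stage
      taskInfo.accumValues += ((acc.name, partialValue.toString))  // this task's delta
    }
  }
}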

core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala

Lines changed: 5 additions & 0 deletions
@@ -17,6 +17,9 @@

 package org.apache.spark.scheduler

+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.storage.RDDInfo

@@ -37,6 +40,8 @@ class StageInfo(
   var completionTime: Option[Long] = None
   /** If the stage failed, the reason why. */
   var failureReason: Option[String] = None
+  /** Terminal values of accumulables updated during this stage. */
+  val accumulatorValues: Map[String, String] = HashMap[String, String]()

   def stageFailed(reason: String) {
     failureReason = Some(reason)

core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala

Lines changed: 7 additions & 0 deletions
@@ -17,6 +17,8 @@

 package org.apache.spark.scheduler

+import scala.collection.mutable.ListBuffer
+
 import org.apache.spark.annotation.DeveloperApi

 /**
@@ -41,6 +43,11 @@ class TaskInfo(
    */
   var gettingResultTime: Long = 0

+  /**
+   * Terminal values of accumulables updated during this task.
+   */
+  val accumValues = ListBuffer[(String, String)]()
+
   /**
    * The time when the task has completed successfully (including the time to remotely fetch
    * results, if necessary).
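Together with StageInfo.accumulatorValues above, this gives a two-level record: a per-stage name-to-latest-total map, and a per-task ordered list of (name, partial) pairs (a list rather than a map, so repeated names survive). A minimal illustration of how the two get populated (the name and values are made up):

import scala.collection.mutable.{HashMap, ListBuffer}

val stageAccums = HashMap[String, String]()      // plays the role of StageInfo.accumulatorValues
val taskAccums  = ListBuffer[(String, String)]() // plays the role of TaskInfo.accumValues

taskAccums += (("badRecords", "3")) // this task contributed 3
stageAccums("badRecords") = "42"    // running total across finished tasks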

core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala

Lines changed: 14 additions & 1 deletion
@@ -17,7 +17,7 @@

 package org.apache.spark.ui.jobs

-import scala.collection.mutable.{HashMap, ListBuffer}
+import scala.collection.mutable.{HashMap, ListBuffer, Map}

 import org.apache.spark._
 import org.apache.spark.annotation.DeveloperApi
@@ -48,6 +48,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener {

   // TODO: Should probably consolidate all following into a single hash map.
   val stageIdToTime = HashMap[Int, Long]()
+  val stageIdToAccumulables = HashMap[Int, Map[String, String]]()
   val stageIdToInputBytes = HashMap[Int, Long]()
   val stageIdToShuffleRead = HashMap[Int, Long]()
   val stageIdToShuffleWrite = HashMap[Int, Long]()
@@ -73,6 +74,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener {
     val stageId = stage.stageId
     // Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage
     poolToActiveStages(stageIdToPool(stageId)).remove(stageId)
+
+    val accumulables = stageIdToAccumulables.getOrElseUpdate(stageId, HashMap[String, String]())
+    stageCompleted.stageInfo.accumulatorValues.foreach { case (name, value) =>
+      accumulables(name) = value
+    }
+
     activeStages.remove(stageId)
     if (stage.failureReason.isEmpty) {
       completedStages += stage
@@ -89,6 +96,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener {
       val toRemove = math.max(retainedStages / 10, 1)
       stages.take(toRemove).foreach { s =>
         stageIdToTime.remove(s.stageId)
+        stageIdToAccumulables.remove(s.stageId)
         stageIdToInputBytes.remove(s.stageId)
         stageIdToShuffleRead.remove(s.stageId)
         stageIdToShuffleWrite.remove(s.stageId)
@@ -147,6 +155,11 @@ class JobProgressListener(conf: SparkConf) extends SparkListener {
     val info = taskEnd.taskInfo

     if (info != null) {
+      val accumulables = stageIdToAccumulables.getOrElseUpdate(sid, HashMap[String, String]())
+      info.accumValues.map { case (name, value) =>
+        accumulables(name) = value
+      }
+
       // create executor summary map if necessary
       val executorSummaryMap = stageIdToExecutorSummaries.getOrElseUpdate(key = sid,
         op = new HashMap[String, ExecutorSummary]())
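Both hooks funnel into the same stageIdToAccumulables map, which is trimmed alongside the other per-stage state. (Minor nit: info.accumValues.map { ... } runs purely for its side effect; foreach would say so more directly.) A sketch of reading the values back out, given a listener reference and a stage id:

val accums = listener.stageIdToAccumulables.getOrElse(stageId, Map.empty[String, String])
accums.foreach { case (name, value) => println(s"$name = $value") }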

core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala

Lines changed: 9 additions & 2 deletions
@@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
 import java.util.Date
 import javax.servlet.http.HttpServletRequest

-import scala.xml.Node
+import scala.xml.{Unparsed, Node}

 import org.apache.spark.ui.{WebUIPage, UIUtils}
 import org.apache.spark.util.{Utils, Distribution}
@@ -57,6 +57,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
     val memoryBytesSpilled = listener.stageIdToMemoryBytesSpilled.getOrElse(stageId, 0L)
     val diskBytesSpilled = listener.stageIdToDiskBytesSpilled.getOrElse(stageId, 0L)
     val hasBytesSpilled = memoryBytesSpilled > 0 && diskBytesSpilled > 0
+    val accumulables = listener.stageIdToAccumulables(stageId)

     var activeTime = 0L
     val now = System.currentTimeMillis
@@ -102,10 +103,14 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
       </ul>
     </div>
     // scalastyle:on
+    val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value")
+    def accumulableRow(acc: (String, String)) = <tr><td>{acc._1}</td><td>{acc._2}</td></tr>
+    val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow, accumulables.toSeq)
+
     val taskHeaders: Seq[String] =
       Seq(
         "Index", "ID", "Attempt", "Status", "Locality Level", "Executor",
-        "Launch Time", "Duration", "GC Time") ++
+        "Launch Time", "Duration", "GC Time", "Accumulators") ++
       {if (hasInput) Seq("Input") else Nil} ++
       {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++
       {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++
@@ -217,6 +222,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
       <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
       <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
       <h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq ++
+      <h4>Accumulators</h4> ++ accumulableTable ++
      <h4>Tasks</h4> ++ taskTable

     UIUtils.headerSparkPage(content, basePath, appName, "Details for Stage %d".format(stageId),
@@ -283,6 +289,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
         <td sorttable_customkey={gcTime.toString}>
           {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""}
         </td>
+        <td>{Unparsed(info.accumValues.map{ case (k, v) => s"$k += $v" }.mkString("<br/>"))}</td>
         <!--
           TODO: Add this back after we add support to hide certain columns.
         <td sorttable_customkey={serializationTime.toString}>
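The per-task cell joins each update with <br/>, so a task that bumped two accumulators would render roughly as (names and values illustrative):

parsedRecords += 12
badRecords += 3

Two observations on this WIP version: Unparsed bypasses XML escaping, so accumulator names and stringified values reach the page verbatim even though names are user-supplied; and listener.stageIdToAccumulables(stageId) throws for a stage with no recorded entry, where a getOrElse with an empty map would be safer.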
