Commit e635168
Nathan Kronenfeld authored and JoshRosen committed
[SPARK-4772] Clear local copies of accumulators as soon as we're done with them
Accumulators keep thread-local copies of themselves. These copies were only cleared at the beginning of a task. This meant that (a) the memory they used was tied up until the next task ran on that thread, and (b) if a thread died, the memory it had used for accumulators was locked up forever on that worker.

This PR clears the thread-local copies of accumulators at the end of each task, in the task's finally block, to make sure they are cleaned up between tasks. It also stores them in a ThreadLocal object, so that if, for some reason, the thread dies, any memory they are using at the time should be freed up.

Author: Nathan Kronenfeld <nkronenfeld@oculusinfo.com>

Closes #3570 from nkronenfeld/Accumulator-Improvements and squashes the following commits:

a581f3f [Nathan Kronenfeld] Change Accumulators to private[spark] instead of adding mima exclude to get around false positive in mima tests
b6c2180 [Nathan Kronenfeld] Include MiMa exclude as per build error instructions - this version incompatibility should be irrelevant, as it will only surface if a master is talking to a worker running a different version of spark.
537baad [Nathan Kronenfeld] Fuller refactoring as intended, incorporating JR's suggestions for ThreadLocal localAccums, and keeping clear(), but also calling it in tasks' finally block, rather than just at the beginning of the task.
39a82f2 [Nathan Kronenfeld] Clear local copies of accumulators as soon as we're done with them

(cherry picked from commit 94b377f)
Signed-off-by: Josh Rosen <joshrosen@databricks.com>
1 parent 0ebbccb commit e635168
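
To illustrate problem (b) from the commit message, here is a minimal sketch of a registry keyed by Thread, in the spirit of the pre-change localAccums map. The names ThreadKeyedRegistry and ThreadKeyedLeakDemo are hypothetical and are not part of Spark. The sketch only shows that an entry registered by a thread stays in the shared map after that thread has exited, unless something removes it explicitly.

```scala
import scala.collection.mutable

// Hypothetical miniature of the pre-change design: one shared map keyed by Thread.
// This is an illustration only, not Spark's Accumulators object.
object ThreadKeyedRegistry {
  private val byThread = mutable.Map[Thread, mutable.Map[Long, Any]]()

  // Store a value under the calling thread's entry, creating the entry if needed.
  def register(id: Long, value: Any): Unit = synchronized {
    byThread
      .getOrElseUpdate(Thread.currentThread, mutable.Map[Long, Any]())
      .update(id, value)
  }

  // Number of threads that still have an entry in the shared map.
  def trackedThreads: Int = synchronized { byThread.size }
}

object ThreadKeyedLeakDemo {
  // Registers a value on a short-lived thread and reports how many
  // thread entries remain once that thread has terminated.
  def entriesAfterWorkerExits(): Int = {
    val worker = new Thread(new Runnable {
      def run(): Unit = ThreadKeyedRegistry.register(1L, "never cleared")
    })
    worker.start()
    worker.join() // the worker has terminated at this point
    ThreadKeyedRegistry.trackedThreads
  }
}
```

Because the shared map holds a strong reference to both the Thread object and its per-thread map, nothing in this design releases the entry when the thread ends without cleaning up after itself.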

2 files changed: +10 -7 lines changed

core/src/main/scala/org/apache/spark/Accumulators.scala

Lines changed: 8 additions & 6 deletions
@@ -19,6 +19,7 @@ package org.apache.spark

 import java.io.{ObjectInputStream, Serializable}
 import java.util.concurrent.atomic.AtomicLong
+import java.lang.ThreadLocal

 import scala.collection.generic.Growable
 import scala.collection.mutable.Map
@@ -248,10 +249,12 @@ trait AccumulatorParam[T] extends AccumulableParam[T, T] {

 // TODO: The multi-thread support in accumulators is kind of lame; check
 // if there's a more intuitive way of doing it right
-private object Accumulators {
+private[spark] object Accumulators {
   // TODO: Use soft references? => need to make readObject work properly then
   val originals = Map[Long, Accumulable[_, _]]()
-  val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]()
+  val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
+    override protected def initialValue() = Map[Long, Accumulable[_, _]]()
+  }
   var lastId: Long = 0

   def newId(): Long = synchronized {
@@ -263,22 +266,21 @@ private object Accumulators {
     if (original) {
       originals(a.id) = a
     } else {
-      val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map())
-      accums(a.id) = a
+      localAccums.get()(a.id) = a
     }
   }

   // Clear the local (non-original) accumulators for the current thread
   def clear() {
     synchronized {
-      localAccums.remove(Thread.currentThread)
+      localAccums.get.clear
     }
   }

   // Get the values of the local accumulators for the current thread (by ID)
   def values: Map[Long, Any] = synchronized {
     val ret = Map[Long, Any]()
-    for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
+    for ((id, accum) <- localAccums.get) {
       ret(id) = accum.localValue
     }
     return ret
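
The switch to a ThreadLocal in the diff above can be tried out in isolation. The following is a minimal sketch under stated assumptions: LocalRegistry and LocalRegistryDemo are hypothetical names, values are plain Any rather than Accumulable, and no synchronization is modeled. It shows that each thread lazily receives its own empty map through initialValue, and that clear() affects only the calling thread's map.

```scala
import scala.collection.mutable

// Hypothetical stand-in for a ThreadLocal-backed registry; not Spark's code.
object LocalRegistry {
  // Each thread gets its own empty map the first time it calls get().
  private val local = new ThreadLocal[mutable.Map[Long, Any]]() {
    override protected def initialValue(): mutable.Map[Long, Any] =
      mutable.Map[Long, Any]()
  }

  def register(id: Long, value: Any): Unit = local.get().update(id, value)

  // Immutable snapshot of the calling thread's entries.
  def values: Map[Long, Any] = local.get().toMap

  // Empties only the calling thread's map.
  def clear(): Unit = local.get().clear()
}

object LocalRegistryDemo {
  // Registers one entry on a fresh thread and returns what that thread sees.
  def valuesSeenByWorker(): Map[Long, Any] = {
    var seen = Map.empty[Long, Any]
    val worker = new Thread(new Runnable {
      def run(): Unit = {
        LocalRegistry.register(2L, "worker")
        seen = LocalRegistry.values
      }
    })
    worker.start()
    worker.join() // join makes the write to `seen` visible to the caller
    seen
  }
}
```

A caller that registers an entry on its own thread and then calls LocalRegistryDemo.valuesSeenByWorker() should find that the worker saw only its own entry, while the caller's entry is untouched until the caller itself invokes clear().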

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 2 additions & 1 deletion
@@ -172,7 +172,6 @@ private[spark] class Executor(
       val startGCTime = gcTime

       try {
-        Accumulators.clear()
         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
         updateDependencies(taskFiles, taskJars)
         task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
@@ -278,6 +277,8 @@ private[spark] class Executor(
         env.shuffleMemoryManager.releaseMemoryForThisThread()
         // Release memory used by this thread for unrolling blocks
         env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
+        // Release memory used by this thread for accumulators
+        Accumulators.clear()
         runningTasks.remove(taskId)
       }
     }
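
Moving the clear() call into the finally block matters most when worker threads are reused. The sketch below is an illustration under assumptions, not the Executor's actual logic: TaskLocalState and FinallyCleanupDemo are hypothetical names, and a single-thread java.util.concurrent pool stands in for the executor's task threads. It runs a failing task followed by a second task on the same pooled thread and reports how many leftover entries the second task observes.

```scala
import java.util.concurrent.{Callable, ExecutionException, Executors}
import scala.collection.mutable

// Hypothetical per-thread state, standing in for thread-local accumulator copies.
object TaskLocalState {
  private val local = new ThreadLocal[mutable.Map[Long, Any]]() {
    override protected def initialValue(): mutable.Map[Long, Any] =
      mutable.Map[Long, Any]()
  }
  def put(id: Long, value: Any): Unit = local.get().update(id, value)
  def size: Int = local.get().size
  def clear(): Unit = local.get().clear()
}

object FinallyCleanupDemo {
  // Runs a failing task, then a second task, on one pooled thread.
  // Returns the number of leftover entries the second task sees.
  def leftoverSeenBySecondTask(cleanUp: Boolean): Int = {
    val pool = Executors.newSingleThreadExecutor()
    try {
      val first = pool.submit(new Runnable {
        def run(): Unit =
          try {
            TaskLocalState.put(1L, "partial result")
            throw new RuntimeException("task failed")
          } finally {
            // Runs whether the task succeeded or threw.
            if (cleanUp) TaskLocalState.clear()
          }
      })
      // The task's exception surfaces here wrapped in an ExecutionException.
      try { first.get(); () } catch { case _: ExecutionException => () }

      val second = pool.submit(new Callable[Int] {
        def call(): Int = TaskLocalState.size
      })
      second.get()
    } finally {
      pool.shutdown()
    }
  }
}
```

With cleanUp set to true the second task starts from an empty map even though the first task threw; with cleanUp set to false the stale entry is still present, which mirrors the situation where memory stays tied up until something on that thread clears it.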
