
Commit f3f6174

apache#10 performance issue when gc blocks [follow up]
1 parent 4af8848 · commit f3f6174

11 files changed: +88 −36 lines


core/src/main/scala/org/apache/spark/ContextCleaner.scala

Lines changed: 39 additions & 25 deletions

@@ -19,7 +19,7 @@ package org.apache.spark
 
 import java.lang.ref.{ReferenceQueue, WeakReference}
 import java.util.Collections
-import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, ScheduledExecutorService, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, ExecutorService, ScheduledExecutorService, TimeUnit}
 
 import scala.collection.JavaConverters._
 
@@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData}
 import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, ThreadUtils, Utils}
 
+
 /**
  * Classes that represent cleaning tasks.
  */
@@ -112,6 +113,15 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   private val blockOnShuffleCleanupTasks = sc.conf.getBoolean(
     "spark.cleaner.referenceTracking.blocking.shuffle", false)
 
+  /**
+   * The cleaning thread size.
+   */
+  private val cleanupTaskThreads = sc.conf.getInt(
+    "spark.cleaner.referenceTracking.cleanupThreadNumber", 100)
+
+  private val cleanupExecutorPool: ExecutorService =
+    ThreadUtils.newDaemonFixedThreadPool(cleanupTaskThreads, "cleanup")
+
   @volatile private var stopped = false
 
   /** Attach a listener object to get information of when objects are cleaned. */
@@ -177,33 +187,37 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   /** Keep cleaning RDD, shuffle, and broadcast state. */
   private def keepCleaning(): Unit = Utils.tryOrStopSparkContext(sc) {
     while (!stopped) {
-      try {
-        val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
-          .map(_.asInstanceOf[CleanupTaskWeakReference])
-        // Synchronize here to avoid being interrupted on stop()
-        synchronized {
-          reference.foreach { ref =>
-            logDebug("Got cleaning task " + ref.task)
-            referenceBuffer.remove(ref)
-            ref.task match {
-              case CleanRDD(rddId) =>
-                doCleanupRDD(rddId, blocking = blockOnCleanupTasks)
-              case CleanShuffle(shuffleId) =>
-                doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks)
-              case CleanBroadcast(broadcastId) =>
-                doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
-              case CleanAccum(accId) =>
-                doCleanupAccum(accId, blocking = blockOnCleanupTasks)
-              case CleanCheckpoint(rddId) =>
-                doCleanCheckpoint(rddId)
-            }
-          }
-        }
-      } catch {
-        case ie: InterruptedException if stopped => // ignore
-        case e: Exception => logError("Error in cleaning thread", e)
-      }
+      Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
+        .map(_.asInstanceOf[CleanupTaskWeakReference]).foreach {
+        r =>
+          referenceBuffer.remove(r)
+          runtCleanTask(r)
+      }
     }
   }
+
+  private def runtCleanTask(ref: CleanupTaskWeakReference) = {
+    cleanupExecutorPool.submit(new Runnable {
+      override def run(): Unit = {
+        try {
+          ref.task match {
+            case CleanRDD(rddId) =>
+              doCleanupRDD(rddId, blocking = blockOnCleanupTasks)
+            case CleanShuffle(shuffleId) =>
+              doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks)
+            case CleanBroadcast(broadcastId) =>
+              doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
+            case CleanAccum(accId) =>
+              doCleanupAccum(accId, blocking = blockOnCleanupTasks)
+            case CleanCheckpoint(rddId) =>
+              doCleanCheckpoint(rddId)
+          }
+        } catch {
+          case ie: InterruptedException if stopped => // ignore
+          case e: Exception => logError("Error in cleaning thread", e)
+        }
+      }
+    })
+  }
 
   /** Perform RDD cleanup. */
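
The main change here is that keepCleaning() no longer performs each cleanup inline under a synchronized block: every dequeued CleanupTaskWeakReference is handed to a fixed-size daemon pool (sized by the new spark.cleaner.referenceTracking.cleanupThreadNumber option, default 100), so one slow or blocking cleanup can no longer stall the reference-queue polling loop. A minimal, self-contained sketch of that pattern follows; it uses plain java.util.concurrent rather than Spark's ThreadUtils, and the Task type and pool size of 4 are illustrative only, not the commit's exact API.

import java.util.concurrent.{Executors, ThreadFactory, TimeUnit}

object CleanupPoolSketch {
  // Hypothetical stand-in for Spark's CleanupTask hierarchy.
  sealed trait Task
  final case class CleanRdd(id: Int) extends Task

  // A fixed pool of daemon threads, roughly what
  // ThreadUtils.newDaemonFixedThreadPool(cleanupTaskThreads, "cleanup") provides.
  private val pool = Executors.newFixedThreadPool(4, new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "cleanup")
      t.setDaemon(true)
      t
    }
  })

  // The polling loop only dequeues and submits; the potentially blocking
  // cleanup work runs on the pool, so reference processing is never stalled.
  def submit(task: Task): Unit = pool.submit(new Runnable {
    override def run(): Unit = task match {
      case CleanRdd(id) => println(s"cleaning RDD $id")
    }
  })

  def main(args: Array[String]): Unit = {
    (1 to 8).foreach(i => submit(CleanRdd(i)))
    pool.shutdown()
    pool.awaitTermination(10, TimeUnit.SECONDS)
  }
}

With this patch the real pool size would be tuned through the new option, e.g. --conf spark.cleaner.referenceTracking.cleanupThreadNumber=32.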

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 1 addition & 1 deletion

@@ -807,7 +807,7 @@ private[spark] object MapOutputTracker extends Logging {
     if (arr.length >= minBroadcastSize) {
       // Use broadcast instead.
       // Important arr(0) is the tag == DIRECT, ignore that while deserializing !
-      val bcast = broadcastManager.newBroadcast(arr, isLocal)
+      val bcast = broadcastManager.newBroadcast(arr, isLocal, null)
       // toByteArray creates copy, so we can reuse out
       out.reset()
       out.write(BROADCAST)

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 5 additions & 2 deletions

@@ -1486,10 +1486,13 @@ class SparkContext(config: SparkConf) extends Logging {
     assertNotStopped()
     require(!classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass),
       "Can not directly broadcast RDDs; instead, call collect() and broadcast the result.")
-    val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
+    val executionId = getLocalProperty("spark.sql.execution.id")
+    val bc = env.broadcastManager.newBroadcast[T](value, isLocal, executionId)
     val callSite = getCallSite
     logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
-    cleaner.foreach(_.registerBroadcastForCleanup(bc))
+    if (executionId == null) {
+      cleaner.foreach(_.registerBroadcastForCleanup(bc))
+    }
     bc
   }
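
With this change SparkContext.broadcast() keys cleanup off the caller's SQL execution id: when the spark.sql.execution.id local property is set (i.e. the broadcast is created while a SQL query is running), the broadcast is not registered with the ContextCleaner and is instead tracked per execution by the BroadcastManager, which releases it when the execution ends (see the BroadcastManager and SQLExecution diffs below). A rough usage sketch, assuming an existing SparkContext named sc; setting the property by hand is only for illustration, since SQLExecution.withNewExecutionId normally sets it.

// Outside any SQL execution the property is unset, so the broadcast is
// registered with the ContextCleaner as before.
val plain = sc.broadcast(Array(1, 2, 3))

// Inside a SQL execution the driver thread carries the execution id as a
// local property; broadcasts created here are recorded under that id instead.
sc.setLocalProperty("spark.sql.execution.id", "42")
val scoped = sc.broadcast(Array(4, 5, 6))
sc.setLocalProperty("spark.sql.execution.id", null)  // clear it again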

core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala

Lines changed: 25 additions & 2 deletions

@@ -19,13 +19,16 @@ package org.apache.spark.broadcast
 
 import java.util.concurrent.atomic.AtomicLong
 
+import scala.collection.mutable.ListBuffer
 import scala.reflect.ClassTag
 
+import avro.shaded.com.google.common.collect.Maps
 import org.apache.commons.collections.map.{AbstractReferenceMap, ReferenceMap}
 
 import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.internal.Logging
 
+
 private[spark] class BroadcastManager(
     val isDriver: Boolean,
     conf: SparkConf,
@@ -34,6 +37,7 @@ private[spark] class BroadcastManager(
 
   private var initialized = false
   private var broadcastFactory: BroadcastFactory = null
+  var cachedBroadcast = Maps.newConcurrentMap[String, ListBuffer[Long]]()
 
   initialize()
 
@@ -54,12 +58,31 @@ private[spark] class BroadcastManager(
 
   private val nextBroadcastId = new AtomicLong(0)
 
+  private[spark] def currentBroadcastId: Long = nextBroadcastId.get()
+
   private[broadcast] val cachedValues = {
     new ReferenceMap(AbstractReferenceMap.HARD, AbstractReferenceMap.WEAK)
   }
 
-  def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
-    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
+  def cleanBroadCast(executionId: String): Unit = {
+    if (cachedBroadcast.containsKey(executionId)) {
+      cachedBroadcast.get(executionId).foreach(broadcastId => unbroadcast(broadcastId, true, false))
+      cachedBroadcast.remove(executionId)
+    }
+  }
+
+  def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, executionId: String): Broadcast[T] = {
+    val broadcastId = nextBroadcastId.getAndIncrement()
+    if (executionId != null) {
+      if (cachedBroadcast.containsKey(executionId)) {
+        cachedBroadcast.get(executionId) += broadcastId
+      } else {
+        val list = new scala.collection.mutable.ListBuffer[Long]
+        list += broadcastId
+        cachedBroadcast.put(executionId, list)
+      }
+    }
+    broadcastFactory.newBroadcast[T](value_, isLocal, broadcastId)
   }
 
   def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
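
The bookkeeping added above is a map from SQL execution id to the broadcast ids created under that execution; cleanBroadCast() then unbroadcasts everything recorded for an execution in one pass. A simplified sketch of the same data structure using a plain java.util.concurrent.ConcurrentHashMap (the commit itself goes through the avro-shaded Guava Maps helper and a mutable ListBuffer); the names below are illustrative only.

import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable.ListBuffer

class BroadcastBookkeeping {
  // executionId -> broadcast ids created while that execution was running
  private val cachedBroadcast = new ConcurrentHashMap[String, ListBuffer[Long]]()

  // Record a broadcast id under the execution that created it.
  def register(executionId: String, broadcastId: Long): Unit = {
    var list = cachedBroadcast.get(executionId)
    if (list == null) {
      val fresh = ListBuffer.empty[Long]
      val existing = cachedBroadcast.putIfAbsent(executionId, fresh)
      list = if (existing != null) existing else fresh
    }
    list.synchronized { list += broadcastId }
  }

  // Release every broadcast recorded for the execution, then forget the entry.
  def cleanExecution(executionId: String)(unbroadcast: Long => Unit): Unit = {
    val list = cachedBroadcast.remove(executionId)
    if (list != null) {
      list.synchronized { list.foreach(unbroadcast) }
    }
  }
}

Here unbroadcast stands in for BroadcastManager.unbroadcast(id, removeFromDriver = true, blocking = false), which is what cleanBroadCast() calls in the diff.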

core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ import org.apache.spark.rdd.RDD
 private[spark] class ResultTask[T, U](
     stageId: Int,
     stageAttemptId: Int,
-    taskBinary: Broadcast[Array[Byte]],
+    val taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     locs: Seq[TaskLocation],
     val outputId: Int,

core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala

Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ import org.apache.spark.shuffle.ShuffleWriter
 private[spark] class ShuffleMapTask(
     stageId: Int,
     stageAttemptId: Int,
-    taskBinary: Broadcast[Array[Byte]],
+    val taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient private var locs: Seq[TaskLocation],
     localProperties: Properties,

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 7 additions & 0 deletions

@@ -530,6 +530,13 @@ private[spark] class TaskSetManager(
   private def maybeFinishTaskSet() {
     if (isZombie && runningTasks == 0) {
       sched.taskSetFinished(this)
+      val broadcastId = taskSet.tasks.head match {
+        case resultTask: ResultTask[Any, Any] =>
+          resultTask.taskBinary.id
+        case shuffleMapTask: ShuffleMapTask =>
+          shuffleMapTask.taskBinary.id
+      }
+      SparkEnv.get.broadcastManager.unbroadcast(broadcastId, true, false)
       if (tasksSuccessful == numTasks) {
         blacklistTracker.foreach(_.updateBlacklistForSuccessfulTaskSet(
           taskSet.stageId,

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 0 additions & 1 deletion

@@ -33,7 +33,6 @@ import scala.util.Random
 import scala.util.control.NonFatal
 
 import com.codahale.metrics.{MetricRegistry, MetricSet}
-import com.google.common.io.CountingOutputStream
 
 import org.apache.spark._
 import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}

core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala

Lines changed: 6 additions & 1 deletion

@@ -202,7 +202,12 @@ class BlockManagerMasterEndpoint(
           0 // zero blocks were removed
       }
     }.toSeq
-
+    val blocksToRemove = blockLocations.keySet().asScala
+      .collect {
+        case broadcastId@BroadcastBlockId(`broadcastId`, _) =>
+          broadcastId
+      }
+    blocksToRemove.foreach(blockLocations.remove)
     Future.sequence(futures)
   }
 

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala

Lines changed: 2 additions & 1 deletion

@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicLong
 
-import org.apache.spark.SparkContext
+import org.apache.spark.SparkEnv
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
 
@@ -84,6 +84,7 @@ object SQLExecution {
       } finally {
         executionIdToQueryExecution.remove(executionId)
        sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
+        SparkEnv.get.broadcastManager.cleanBroadCast(executionId.toString)
      }
    }
 
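
Taken together, the pieces work as follows: withNewExecutionId sets spark.sql.execution.id as a driver-thread local property, any broadcast created while the query runs is recorded under that id by the BroadcastManager instead of being handed to the ContextCleaner, and when the execution finishes the finally block above calls cleanBroadCast(executionId.toString) so the whole batch is unbroadcast eagerly rather than waiting for GC-driven weak-reference cleanup. A condensed sketch of that lifecycle, assuming a running SparkContext; the real logic lives in SQLExecution.withNewExecutionId and the method name here is illustrative.

import org.apache.spark.{SparkContext, SparkEnv}

// Sketch only: mirrors the set-property / run / clean-up shape of the patch.
def runQuery[T](sc: SparkContext, executionId: Long)(body: => T): T = {
  sc.setLocalProperty("spark.sql.execution.id", executionId.toString)
  try {
    body  // broadcasts created here are keyed under executionId
  } finally {
    sc.setLocalProperty("spark.sql.execution.id", null)
    // Release every broadcast that this execution created.
    SparkEnv.get.broadcastManager.cleanBroadCast(executionId.toString)
  }
}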
