Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-46895][CORE] Replace Timer with single thread scheduled executor #44718

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@

package org.apache.spark

import java.util.{Timer, TimerTask}
import java.util.concurrent.ConcurrentHashMap
import java.util.TimerTask
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.function.Consumer

import scala.collection.mutable.{ArrayBuffer, HashSet}

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}
import org.apache.spark.util.ThreadUtils

/**
* For each barrier stage attempt, only at most one barrier() call can be active at any time, thus
Expand All @@ -51,7 +52,8 @@ private[spark] class BarrierCoordinator(

// TODO SPARK-25030 Create a Timer() in the mainClass submitted to SparkSubmit makes it unable to
// fetch result, we shall fix the issue.
private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer")
private lazy val timer = ThreadUtils.newSingleThreadScheduledExecutor(
srowen marked this conversation as resolved.
Show resolved Hide resolved
"BarrierCoordinator barrier epoch increment timer")

// Listen to StageCompleted event, clear corresponding ContextBarrierState.
private val listener = new SparkListener {
Expand All @@ -77,6 +79,7 @@ private[spark] class BarrierCoordinator(
states.forEachValue(1, clearStateConsumer)
states.clear()
listenerBus.removeListener(listener)
ThreadUtils.shutdown(timer)
} finally {
super.onStop()
}
Expand Down Expand Up @@ -168,7 +171,7 @@ private[spark] class BarrierCoordinator(
// we may timeout for the sync.
if (requesters.isEmpty) {
initTimerTask(this)
timer.schedule(timerTask, timeoutInSecs * 1000)
timer.schedule(timerTask, timeoutInSecs, TimeUnit.SECONDS)
}
// Add the requester to array of RPCCallContexts pending for reply.
requesters += requester
Expand Down
14 changes: 10 additions & 4 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

package org.apache.spark

import java.util.{Properties, Timer, TimerTask}
import java.util.{Properties, TimerTask}
import java.util.concurrent.{ScheduledThreadPoolExecutor, TimeUnit}

import scala.concurrent.duration._
import scala.jdk.CollectionConverters._
Expand Down Expand Up @@ -70,8 +71,8 @@ class BarrierTaskContext private[spark] (
s"current barrier epoch is $barrierEpoch.")
}
}
// Log the update of global sync every 60 seconds.
timer.schedule(timerTask, 60000, 60000)
// Log the update of global sync every 1 minute.
timer.scheduleAtFixedRate(timerTask, 1, 1, TimeUnit.MINUTES)

try {
val abortableRpcFuture = barrierCoordinator.askAbortable[Array[String]](
Expand Down Expand Up @@ -283,6 +284,11 @@ object BarrierTaskContext {
@Since("2.4.0")
def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext]

private val timer = new Timer("Barrier task timer for barrier() calls.")
private val timer = {
val executor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
"Barrier task timer for barrier() calls.")
assert(executor.isInstanceOf[ScheduledThreadPoolExecutor])
executor.asInstanceOf[ScheduledThreadPoolExecutor]
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.scheduler

import java.nio.ByteBuffer
import java.util.{Timer, TimerTask}
import java.util.TimerTask
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicLong

Expand Down Expand Up @@ -135,7 +135,8 @@ private[spark] class TaskSchedulerImpl(

@volatile private var hasReceivedTask = false
@volatile private var hasLaunchedTask = false
private val starvationTimer = new Timer("task-starvation-timer", true)
private val starvationTimer = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
"task-starvation-timer")

// Incrementing task IDs
val nextTaskId = new AtomicLong(0)
Expand Down Expand Up @@ -166,7 +167,7 @@ private[spark] class TaskSchedulerImpl(

protected val executorIdToHost = new HashMap[String, String]

private val abortTimer = new Timer("task-abort-timer", true)
private val abortTimer = ThreadUtils.newDaemonSingleThreadScheduledExecutor("task-abort-timer")
// Exposed for testing
val unschedulableTaskSetToExpiryTime = new HashMap[TaskSetManager, Long]

Expand Down Expand Up @@ -282,7 +283,7 @@ private[spark] class TaskSchedulerImpl(
this.cancel()
}
}
}, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
}, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS, TimeUnit.MILLISECONDS)
}
hasReceivedTask = true
}
Expand Down Expand Up @@ -737,7 +738,7 @@ private[spark] class TaskSchedulerImpl(
logInfo(s"Waiting for $timeout ms for completely " +
s"excluded task to be schedulable again before aborting stage ${taskSet.stageId}.")
abortTimer.schedule(
createUnschedulableTaskSetAbortTimer(taskSet, taskIndex), timeout)
createUnschedulableTaskSetAbortTimer(taskSet, taskIndex), timeout, TimeUnit.MILLISECONDS)
}

private def createUnschedulableTaskSetAbortTimer(
Expand Down Expand Up @@ -963,8 +964,8 @@ private[spark] class TaskSchedulerImpl(
barrierCoordinator.stop()
}
}
starvationTimer.cancel()
abortTimer.cancel()
ThreadUtils.shutdown(starvationTimer)
ThreadUtils.shutdown(abortTimer)
}

override def defaultParallelism(): Int = backend.defaultParallelism()
Expand Down
11 changes: 4 additions & 7 deletions core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@

package org.apache.spark.ui

import java.util.concurrent.{Executors, TimeUnit}

import com.google.common.util.concurrent.ThreadFactoryBuilder
import java.util.concurrent.TimeUnit

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.UI._
import org.apache.spark.status.api.v1.StageData
import org.apache.spark.util.ThreadUtils

/**
* ConsoleProgressBar shows the progress of stages in the next line of the console. It poll the
Expand All @@ -48,9 +47,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
private var lastProgressBar = ""

// Schedule a refresh thread to run periodically
private val threadFactory =
new ThreadFactoryBuilder().setDaemon(true).setNameFormat("refresh progress").build()
private val timer = Executors.newSingleThreadScheduledExecutor(threadFactory)
private val timer = ThreadUtils.newDaemonSingleThreadScheduledExecutor("refresh progress")
timer.scheduleAtFixedRate(
() => refresh(), firstDelayMSec, updatePeriodMSec, TimeUnit.MILLISECONDS)

Expand Down Expand Up @@ -124,5 +121,5 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
* Tear down the timer thread. The timer thread is a GC root, and it retains the entire
* SparkContext if it's not terminated.
*/
def stop(): Unit = timer.shutdown()
def stop(): Unit = ThreadUtils.shutdown(timer)
}
16 changes: 14 additions & 2 deletions core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ private[spark] object ThreadUtils {
}

/**
* Wrapper over newSingleThreadExecutor.
* Wrapper over newFixedThreadPool with single daemon thread.
*/
def newDaemonSingleThreadExecutor(threadName: String): ThreadPoolExecutor = {
val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build()
Expand Down Expand Up @@ -189,7 +189,7 @@ private[spark] object ThreadUtils {
}

/**
* Wrapper over ScheduledThreadPoolExecutor.
* Wrapper over ScheduledThreadPoolExecutor the pool with daemon threads.
*/
def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = {
val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build()
Expand All @@ -200,6 +200,18 @@ private[spark] object ThreadUtils {
executor
}

/**
* Wrapper over ScheduledThreadPoolExecutor the pool with non-daemon threads.
*/
def newSingleThreadScheduledExecutor(threadName: String): ScheduledThreadPoolExecutor = {
val threadFactory = new ThreadFactoryBuilder().setNameFormat(threadName).build()
val executor = new ScheduledThreadPoolExecutor(1, threadFactory)
// By default, a cancelled task is not automatically removed from the work queue until its delay
// elapses. We have to enable it manually.
executor.setRemoveOnCancelPolicy(true)
executor
}

/**
* Wrapper over ScheduledThreadPoolExecutor.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ private void acceptConnections() {
try {
while (running) {
final Socket client = server.accept();
TimerTask timeout = new TimerTask() {
TimerTask timerTask = new TimerTask() {
@Override
public void run() {
LOG.warning("Timed out waiting for hello message from client.");
Expand All @@ -236,7 +236,7 @@ public void run() {
}
}
};
ServerConnection clientConnection = new ServerConnection(client, timeout);
ServerConnection clientConnection = new ServerConnection(client, timerTask);
Thread clientThread = factory.newThread(clientConnection);
clientConnection.setConnectionThread(clientThread);
synchronized (clients) {
Expand All @@ -247,9 +247,9 @@ public void run() {
// 0 is used for testing to avoid issues with clock resolution / thread scheduling,
// and force an immediate timeout.
if (timeoutMs > 0) {
timeoutTimer.schedule(timeout, timeoutMs);
timeoutTimer.schedule(timerTask, timeoutMs);
} else {
timeout.run();
timerTask.run();
}

clientThread.start();
Expand Down