From 8d6f7f195688729ddf009d1d4d00a1b08a440eac Mon Sep 17 00:00:00 2001 From: beliefer Date: Tue, 6 Feb 2024 17:23:03 +0800 Subject: [PATCH] [SPARK-46895][CORE] Replace Timer with single thread scheduled executor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This PR propose to replace `Timer` with single thread scheduled executor. ### Why are the changes needed? The javadoc recommends `ScheduledThreadPoolExecutor` instead of `Timer`. ![屏幕快照 2024-01-12 下午12 47 57](https://github.com/apache/spark/assets/8486025/4fc5ed61-6bb9-4768-915a-ad919a067d04) This change based on the following two points. **System time sensitivity** Timer scheduling is based on the absolute time of the operating system and is sensitive to the operating system's time. Once the operating system's time changes, Timer scheduling is no longer precise. The scheduled Thread Pool Executor scheduling is based on relative time and is not affected by changes in operating system time. **Are anomalies captured** Timer does not capture exceptions thrown by Timer Tasks, and in addition, Timer is single threaded. Once a scheduling task encounters an exception, the entire thread will terminate and other tasks that need to be scheduled will no longer be executed. The scheduled Thread Pool Executor implements scheduling functions based on a thread pool. After a task throws an exception, other tasks can still execute normally. ### Does this PR introduce _any_ user-facing change? 'No'. ### How was this patch tested? GA tests. ### Was this patch authored or co-authored using generative AI tooling? 'No'. Closes #44718 from beliefer/replace-timer-with-threadpool. Authored-by: beliefer Signed-off-by: yangjie01 (cherry picked from commit 5d5b3a54b7b5fb4308fe40da696ba805c72983fc) --- .../org/apache/spark/BarrierCoordinator.scala | 11 +++++++---- .../org/apache/spark/BarrierTaskContext.scala | 14 ++++++++++---- .../spark/scheduler/TaskSchedulerImpl.scala | 15 ++++++++------- .../org/apache/spark/ui/ConsoleProgressBar.scala | 11 ++++------- .../org/apache/spark/util/ThreadUtils.scala | 16 ++++++++++++++-- .../apache/spark/launcher/LauncherServer.java | 8 ++++---- 6 files changed, 47 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 8ffccdf664b2d..038c62141dab9 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -17,8 +17,8 @@ package org.apache.spark -import java.util.{Timer, TimerTask} -import java.util.concurrent.ConcurrentHashMap +import java.util.TimerTask +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import java.util.function.Consumer import scala.collection.mutable.{ArrayBuffer, HashSet} @@ -26,6 +26,7 @@ import scala.collection.mutable.{ArrayBuffer, HashSet} import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} +import org.apache.spark.util.ThreadUtils /** * For each barrier stage attempt, only at most one barrier() call can be active at any time, thus @@ -51,7 +52,8 @@ private[spark] class BarrierCoordinator( // TODO SPARK-25030 Create a Timer() in the mainClass submitted to SparkSubmit makes it unable to // fetch result, we shall fix the issue. - private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer") + private lazy val timer = ThreadUtils.newSingleThreadScheduledExecutor( + "BarrierCoordinator barrier epoch increment timer") // Listen to StageCompleted event, clear corresponding ContextBarrierState. private val listener = new SparkListener { @@ -77,6 +79,7 @@ private[spark] class BarrierCoordinator( states.forEachValue(1, clearStateConsumer) states.clear() listenerBus.removeListener(listener) + ThreadUtils.shutdown(timer) } finally { super.onStop() } @@ -168,7 +171,7 @@ private[spark] class BarrierCoordinator( // we may timeout for the sync. if (requesters.isEmpty) { initTimerTask(this) - timer.schedule(timerTask, timeoutInSecs * 1000) + timer.schedule(timerTask, timeoutInSecs, TimeUnit.SECONDS) } // Add the requester to array of RPCCallContexts pending for reply. requesters += requester diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 94ba3fe64a859..f99b25f3bc127 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -17,7 +17,8 @@ package org.apache.spark -import java.util.{Properties, Timer, TimerTask} +import java.util.{Properties, TimerTask} +import java.util.concurrent.{ScheduledThreadPoolExecutor, TimeUnit} import scala.collection.JavaConverters._ import scala.concurrent.duration._ @@ -69,8 +70,8 @@ class BarrierTaskContext private[spark] ( s"current barrier epoch is $barrierEpoch.") } } - // Log the update of global sync every 60 seconds. - timer.schedule(timerTask, 60000, 60000) + // Log the update of global sync every 1 minute. + timer.scheduleAtFixedRate(timerTask, 1, 1, TimeUnit.MINUTES) try { val abortableRpcFuture = barrierCoordinator.askAbortable[Array[String]]( @@ -282,6 +283,11 @@ object BarrierTaskContext { @Since("2.4.0") def get(): BarrierTaskContext = TaskContext.get().asInstanceOf[BarrierTaskContext] - private val timer = new Timer("Barrier task timer for barrier() calls.") + private val timer = { + val executor = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "Barrier task timer for barrier() calls.") + assert(executor.isInstanceOf[ScheduledThreadPoolExecutor]) + executor.asInstanceOf[ScheduledThreadPoolExecutor] + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 5e9716dfcfe90..72461339fe946 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import java.util.{Timer, TimerTask} +import java.util.TimerTask import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import java.util.concurrent.atomic.AtomicLong @@ -135,7 +135,8 @@ private[spark] class TaskSchedulerImpl( @volatile private var hasReceivedTask = false @volatile private var hasLaunchedTask = false - private val starvationTimer = new Timer("task-starvation-timer", true) + private val starvationTimer = ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "task-starvation-timer") // Incrementing task IDs val nextTaskId = new AtomicLong(0) @@ -166,7 +167,7 @@ private[spark] class TaskSchedulerImpl( protected val executorIdToHost = new HashMap[String, String] - private val abortTimer = new Timer("task-abort-timer", true) + private val abortTimer = ThreadUtils.newDaemonSingleThreadScheduledExecutor("task-abort-timer") // Exposed for testing val unschedulableTaskSetToExpiryTime = new HashMap[TaskSetManager, Long] @@ -282,7 +283,7 @@ private[spark] class TaskSchedulerImpl( this.cancel() } } - }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS) + }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS, TimeUnit.MILLISECONDS) } hasReceivedTask = true } @@ -776,7 +777,7 @@ private[spark] class TaskSchedulerImpl( logInfo(s"Waiting for $timeout ms for completely " + s"excluded task to be schedulable again before aborting stage ${taskSet.stageId}.") abortTimer.schedule( - createUnschedulableTaskSetAbortTimer(taskSet, taskIndex), timeout) + createUnschedulableTaskSetAbortTimer(taskSet, taskIndex), timeout, TimeUnit.MILLISECONDS) } private def createUnschedulableTaskSetAbortTimer( @@ -1002,8 +1003,8 @@ private[spark] class TaskSchedulerImpl( barrierCoordinator.stop() } } - starvationTimer.cancel() - abortTimer.cancel() + ThreadUtils.shutdown(starvationTimer) + ThreadUtils.shutdown(abortTimer) } override def defaultParallelism(): Int = backend.defaultParallelism() diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 63c82d0665d43..0ab7040e12d36 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -17,14 +17,13 @@ package org.apache.spark.ui -import java.util.concurrent.{Executors, TimeUnit} - -import com.google.common.util.concurrent.ThreadFactoryBuilder +import java.util.concurrent.TimeUnit import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.UI._ import org.apache.spark.status.api.v1.StageData +import org.apache.spark.util.ThreadUtils /** * ConsoleProgressBar shows the progress of stages in the next line of the console. It poll the @@ -48,9 +47,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { private var lastProgressBar = "" // Schedule a refresh thread to run periodically - private val threadFactory = - new ThreadFactoryBuilder().setDaemon(true).setNameFormat("refresh progress").build() - private val timer = Executors.newSingleThreadScheduledExecutor(threadFactory) + private val timer = ThreadUtils.newDaemonSingleThreadScheduledExecutor("refresh progress") timer.scheduleAtFixedRate( () => refresh(), firstDelayMSec, updatePeriodMSec, TimeUnit.MILLISECONDS) @@ -124,5 +121,5 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { * Tear down the timer thread. The timer thread is a GC root, and it retains the entire * SparkContext if it's not terminated. */ - def stop(): Unit = timer.shutdown() + def stop(): Unit = ThreadUtils.shutdown(timer) } diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 2d3d6ec89ffbd..fb9723b9da6af 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -160,7 +160,7 @@ private[spark] object ThreadUtils { } /** - * Wrapper over newSingleThreadExecutor. + * Wrapper over newFixedThreadPool with single daemon thread. */ def newDaemonSingleThreadExecutor(threadName: String): ThreadPoolExecutor = { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() @@ -189,7 +189,7 @@ private[spark] object ThreadUtils { } /** - * Wrapper over ScheduledThreadPoolExecutor. + * Wrapper over ScheduledThreadPoolExecutor the pool with daemon threads. */ def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() @@ -200,6 +200,18 @@ private[spark] object ThreadUtils { executor } + /** + * Wrapper over ScheduledThreadPoolExecutor the pool with non-daemon threads. + */ + def newSingleThreadScheduledExecutor(threadName: String): ScheduledThreadPoolExecutor = { + val threadFactory = new ThreadFactoryBuilder().setNameFormat(threadName).build() + val executor = new ScheduledThreadPoolExecutor(1, threadFactory) + // By default, a cancelled task is not automatically removed from the work queue until its delay + // elapses. We have to enable it manually. + executor.setRemoveOnCancelPolicy(true) + executor + } + /** * Wrapper over ScheduledThreadPoolExecutor. */ diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 125205f416d35..48d0c5fb92049 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -225,7 +225,7 @@ private void acceptConnections() { try { while (running) { final Socket client = server.accept(); - TimerTask timeout = new TimerTask() { + TimerTask timerTask = new TimerTask() { @Override public void run() { LOG.warning("Timed out waiting for hello message from client."); @@ -236,7 +236,7 @@ public void run() { } } }; - ServerConnection clientConnection = new ServerConnection(client, timeout); + ServerConnection clientConnection = new ServerConnection(client, timerTask); Thread clientThread = factory.newThread(clientConnection); clientConnection.setConnectionThread(clientThread); synchronized (clients) { @@ -247,9 +247,9 @@ public void run() { // 0 is used for testing to avoid issues with clock resolution / thread scheduling, // and force an immediate timeout. if (timeoutMs > 0) { - timeoutTimer.schedule(timeout, timeoutMs); + timeoutTimer.schedule(timerTask, timeoutMs); } else { - timeout.run(); + timerTask.run(); } clientThread.start();