diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 33bdc778e2371..6e449e4dc1112 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -93,6 +93,10 @@ private[spark] abstract class Task[T]( require(cpus > 0, "CPUs per task should be > 0") + // Use the blockManager at the start of the task throughout the task - particularly in + // case of local mode, a SparkEnv can be initialized when spark context is restarted + // and we want to ensure the right env and block manager are used (given lazy initialization of + // block manager) val blockManager = SparkEnv.get.blockManager blockManager.registerTask(taskAttemptId) // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index a3947493f2831..d08e75733abfd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties import java.util.concurrent.Semaphore -import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} +import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.ArrayBuffer @@ -32,7 +32,7 @@ import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin} import org.apache.spark.executor.{Executor, TaskMetrics, TaskMetricsSuite} import org.apache.spark.internal.Logging import org.apache.spark.internal.config.METRICS_CONF -import org.apache.spark.memory.{MemoryManager, TaskMemoryManager} +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import 
org.apache.spark.rdd.RDD @@ -684,7 +684,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(isFailed) } - test("Ensure the right block manager is used to unroll memory for task") { + test("SPARK-46947: ensure the correct block manager is used to unroll memory for task") { import BlockManagerValidationPlugin._ BlockManagerValidationPlugin.resetState() @@ -697,45 +697,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark override def compute(split: Partition, context: TaskContext): Iterator[String] = { context.addTaskCompletionListener(new TaskCompletionListener { override def onTaskCompletion(context: TaskContext): Unit = { - var done = false - while (!done) { - try { - releaseTaskSem.acquire(1) - done = true - } catch { - case iEx: InterruptedException => - // ignore thread interruption - logInfo("Ignoring thread interruption", iEx) - } + try { + releaseTaskSem.acquire() + } catch { + case _: InterruptedException => + // ignore thread interruption } } }) - taskMemoryManager.set(SparkEnv.get.blockManager.memoryManager) taskStartedSem.release() - Iterator("hi") + Iterator.empty } } - // submit the job, but dont block this thread + // submit the job, but don't block this thread rdd.collectAsync() // wait for task to start - taskStartedSem.acquire(1) + taskStartedSem.acquire() sc.stop() assert(sc.isStopped) - // create a new SparkContext + // create a new SparkContext which will be blocked for a certain amount of time + // while initializing the driver plugin below val conf = new SparkConf() conf.set("spark.plugins", classOf[BlockManagerValidationPlugin].getName) - BlockManagerValidationPlugin.threadLocalState.set( - () => { - val tmm = taskMemoryManager.get() - tmm.synchronized { - releaseTaskSem.release(1) - tmm.wait() - } - Thread.sleep(2500) - } - ) sc = new SparkContext("local", "test", conf) } } @@ -749,32 +734,28 @@ private object TaskContextSuite extends Logging { } class 
BlockManagerValidationPlugin extends SparkPlugin { - override def driverPlugin(): DriverPlugin = { new DriverPlugin() { - // We dont really do anything - other than notifying that plugin creation has completed - // and then wait for a while - Option(BlockManagerValidationPlugin.threadLocalState.get()).foreach(_.apply()) + // release the task thread, then block the current thread for a while so that the task + // thread can progress and reproduce the issue. + BlockManagerValidationPlugin.releaseTaskSem.release() + Thread.sleep(2500) } } override def executorPlugin(): ExecutorPlugin = { new ExecutorPlugin() { - // nothing to see here + // do nothing } } } object BlockManagerValidationPlugin { - val threadLocalState = new ThreadLocal[() => Unit]() - val releaseTaskSem = new Semaphore(0) - val taskMemoryManager = new AtomicReference[MemoryManager](null) val taskStartedSem = new Semaphore(0) def resetState(): Unit = { releaseTaskSem.drainPermits() taskStartedSem.drainPermits() - taskMemoryManager.set(null) } }