Skip to content

Commit

Permalink
cleanup block manager test
Browse files Browse the repository at this point in the history
  • Loading branch information
sunchao committed Feb 26, 2024
1 parent 59a2acb commit 0f41c4b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 37 deletions.
4 changes: 4 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ private[spark] abstract class Task[T](

require(cpus > 0, "CPUs per task should be > 0")

// Use the blockManager at start of the task through out the task - particularly in
// case of local mode, a SparkEnv can be initialized when spark context is restarted
// and we want to ensure the right env and block manager is used (given lazy initialization of
// block manager)
val blockManager = SparkEnv.get.blockManager
blockManager.registerTask(taskAttemptId)
// TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.scheduler

import java.util.Properties
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.ArrayBuffer

Expand All @@ -32,7 +32,7 @@ import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}
import org.apache.spark.executor.{Executor, TaskMetrics, TaskMetricsSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.METRICS_CONF
import org.apache.spark.memory.{MemoryManager, TaskMemoryManager}
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.JvmSource
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -684,7 +684,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
assert(isFailed)
}

test("Ensure the right block manager is used to unroll memory for task") {
test("SPARK-46947: ensure the correct block manager is used to unroll memory for task") {
import BlockManagerValidationPlugin._
BlockManagerValidationPlugin.resetState()

Expand All @@ -697,45 +697,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
override def compute(split: Partition, context: TaskContext): Iterator[String] = {
context.addTaskCompletionListener(new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
var done = false
while (!done) {
try {
releaseTaskSem.acquire(1)
done = true
} catch {
case iEx: InterruptedException =>
// ignore thread interruption
logInfo("Ignoring thread interruption", iEx)
}
try {
releaseTaskSem.acquire()
} catch {
case _: InterruptedException =>
// ignore thread interruption
}
}
})
taskMemoryManager.set(SparkEnv.get.blockManager.memoryManager)
taskStartedSem.release()
Iterator("hi")
Iterator.empty
}
}
// submit the job, but dont block this thread
// submit the job, but don't block this thread
rdd.collectAsync()
// wait for task to start
taskStartedSem.acquire(1)
taskStartedSem.acquire()

sc.stop()
assert(sc.isStopped)

// create a new SparkContext
// create a new SparkContext which will be blocked for certain amount of time
// during initializing the driver plugin below
val conf = new SparkConf()
conf.set("spark.plugins", classOf[BlockManagerValidationPlugin].getName)
BlockManagerValidationPlugin.threadLocalState.set(
() => {
val tmm = taskMemoryManager.get()
tmm.synchronized {
releaseTaskSem.release(1)
tmm.wait()
}
Thread.sleep(2500)
}
)
sc = new SparkContext("local", "test", conf)
}
}
Expand All @@ -749,32 +734,28 @@ private object TaskContextSuite extends Logging {
}

class BlockManagerValidationPlugin extends SparkPlugin {

override def driverPlugin(): DriverPlugin = {
new DriverPlugin() {
// We dont really do anything - other than notifying that plugin creation has completed
// and then wait for a while
Option(BlockManagerValidationPlugin.threadLocalState.get()).foreach(_.apply())
// does nothing but block the current thread for certain time for the task thread
// to progress and reproduce the issue.
BlockManagerValidationPlugin.releaseTaskSem.release()
Thread.sleep(2500)
}
}
override def executorPlugin(): ExecutorPlugin = {
new ExecutorPlugin() {
// nothing to see here
// do nothing
}
}
}

object BlockManagerValidationPlugin {
val threadLocalState = new ThreadLocal[() => Unit]()

val releaseTaskSem = new Semaphore(0)
val taskMemoryManager = new AtomicReference[MemoryManager](null)
val taskStartedSem = new Semaphore(0)

def resetState(): Unit = {
releaseTaskSem.drainPermits()
taskStartedSem.drainPermits()
taskMemoryManager.set(null)
}
}

Expand Down

0 comments on commit 0f41c4b

Please sign in to comment.