review comments
sunchao committed Feb 26, 2024
1 parent af68ece commit 59a2acb
Showing 3 changed files with 93 additions and 11 deletions.
core/src/main/scala/org/apache/spark/executor/Executor.scala
1 change: 0 additions & 1 deletion

@@ -589,7 +589,6 @@ private[spark] class Executor(
          taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
        task.localProperties = taskDescription.properties
        task.setTaskMemoryManager(taskMemoryManager)
-       task.setBlockManager(env.blockManager)

        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
        // continue executing the task.
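This executor-side hunk, together with the Task.scala change below, replaces an injected BlockManager reference with a lookup from the live SparkEnv at run time. A minimal, self-contained sketch of the hazard being removed (all names here are illustrative stand-ins, not Spark code): a reference captured when the task is created goes stale once the environment is replaced, while a run-time lookup always sees the live instance.

object StaleReferenceDemo {
  final class Service(val id: String)
  // Stand-in for SparkEnv: a mutable, process-wide locator.
  object Env { @volatile var current: Service = new Service("env-1") }

  // Old wiring: the service is injected once, when the task is created.
  final class InjectedTask(service: Service) { def run(): String = service.id }
  // New wiring: the service is resolved from the locator when the task runs.
  final class LocatorTask { def run(): String = Env.current.id }

  def main(args: Array[String]): Unit = {
    val injected = new InjectedTask(Env.current)
    val locator = new LocatorTask
    Env.current = new Service("env-2") // simulate an environment restart
    println(injected.run()) // prints "env-1" -- a stale reference
    println(locator.run())  // prints "env-2" -- the live environment
  }
}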
core/src/main/scala/org/apache/spark/scheduler/Task.scala
9 changes: 2 additions & 7 deletions

@@ -28,7 +28,6 @@ import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.resource.ResourceInformation
-import org.apache.spark.storage.BlockManager
import org.apache.spark.util._

/**
@@ -94,7 +93,8 @@ private[spark] abstract class Task[T](

    require(cpus > 0, "CPUs per task should be > 0")

-   SparkEnv.get.blockManager.registerTask(taskAttemptId)
+   val blockManager = SparkEnv.get.blockManager
+   blockManager.registerTask(taskAttemptId)
    // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
    // the stage is barrier.
    val taskContext = new TaskContextImpl(
@@ -165,16 +165,11 @@
  }

  private var taskMemoryManager: TaskMemoryManager = _
- private var blockManager: BlockManager = _

  def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = {
    this.taskMemoryManager = taskMemoryManager
  }

- def setBlockManager(blockManager: BlockManager): Unit = {
-   this.blockManager = blockManager
- }
-
  def runTask(context: TaskContext): T

  def preferredLocations: Seq[TaskLocation] = Nil
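On the Task side, the block manager is now resolved from SparkEnv.get on each run rather than cached in a field. Below is a small sketch of that lifecycle under stated assumptions: BlockManagerLike and its two methods are hypothetical stand-ins rather than Spark's BlockManager API, and the paired release in the finally block mirrors bookkeeping that Task.run performs outside this hunk.

import scala.collection.mutable

object RunTimeResolutionDemo {
  trait BlockManagerLike {
    def registerTask(taskAttemptId: Long): Unit
    def releaseAllLocksForTask(taskAttemptId: Long): Unit
  }

  final class SimpleBlockManager extends BlockManagerLike {
    private val registered = mutable.Set.empty[Long]
    override def registerTask(taskAttemptId: Long): Unit = {
      registered += taskAttemptId
      println(s"registered task $taskAttemptId")
    }
    override def releaseAllLocksForTask(taskAttemptId: Long): Unit = {
      registered -= taskAttemptId
      println(s"released task $taskAttemptId")
    }
  }

  // Stand-in for SparkEnv: swapped out when the environment restarts.
  object FakeEnv { @volatile var blockManager: BlockManagerLike = new SimpleBlockManager }

  def runTaskAttempt(taskAttemptId: Long)(body: => Unit): Unit = {
    // Resolved at run time, as in the new Task.run; never cached across runs.
    val blockManager = FakeEnv.blockManager
    blockManager.registerTask(taskAttemptId)
    try body finally blockManager.releaseAllLocksForTask(taskAttemptId)
  }

  def main(args: Array[String]): Unit =
    runTaskAttempt(42L) { println("task body runs against the live block manager") }
}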
core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
94 changes: 91 additions & 3 deletions

@@ -18,7 +18,8 @@
package org.apache.spark.scheduler

import java.util.Properties
-import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.Semaphore
+import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}

import scala.collection.mutable.ArrayBuffer

@@ -27,9 +28,11 @@ import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter

import org.apache.spark._
+import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}
import org.apache.spark.executor.{Executor, TaskMetrics, TaskMetricsSuite}
+import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.METRICS_CONF
-import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.memory.{MemoryManager, TaskMemoryManager}
import org.apache.spark.metrics.source.JvmSource
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
@@ -680,14 +683,99 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext
    context.markTaskCompleted(None)
    assert(isFailed)
  }

test("Ensure the right block manager is used to unroll memory for task") {
import BlockManagerValidationPlugin._
BlockManagerValidationPlugin.resetState()

// run a task which ignores thread interruption when spark context is shutdown
sc = new SparkContext("local", "test")

val rdd = new RDD[String](sc, List()) {
override def getPartitions = Array[Partition](StubPartition(0))

override def compute(split: Partition, context: TaskContext): Iterator[String] = {
context.addTaskCompletionListener(new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
var done = false
while (!done) {
try {
releaseTaskSem.acquire(1)
done = true
} catch {
case iEx: InterruptedException =>
// ignore thread interruption
logInfo("Ignoring thread interruption", iEx)
}
}
}
})
taskMemoryManager.set(SparkEnv.get.blockManager.memoryManager)
taskStartedSem.release()
Iterator("hi")
}
}
// submit the job, but dont block this thread
rdd.collectAsync()
// wait for task to start
taskStartedSem.acquire(1)

sc.stop()
assert(sc.isStopped)

// create a new SparkContext
val conf = new SparkConf()
conf.set("spark.plugins", classOf[BlockManagerValidationPlugin].getName)
BlockManagerValidationPlugin.threadLocalState.set(
() => {
val tmm = taskMemoryManager.get()
tmm.synchronized {
releaseTaskSem.release(1)
tmm.wait()
}
Thread.sleep(2500)
}
)
sc = new SparkContext("local", "test", conf)
}
}

-private object TaskContextSuite {
+private object TaskContextSuite extends Logging {
  @volatile var completed = false

  @volatile var lastError: Throwable = _

  class FakeTaskFailureException extends Exception("Fake task failure")
}

+class BlockManagerValidationPlugin extends SparkPlugin {
+
+  override def driverPlugin(): DriverPlugin = {
+    new DriverPlugin() {
+      // We don't really do anything here, other than notifying that plugin creation
+      // has completed, and then waiting for a while.
+      Option(BlockManagerValidationPlugin.threadLocalState.get()).foreach(_.apply())
+    }
+  }
+  override def executorPlugin(): ExecutorPlugin = {
+    new ExecutorPlugin() {
+      // Nothing to see here.
+    }
+  }
+}
+
+object BlockManagerValidationPlugin {
+  val threadLocalState = new ThreadLocal[() => Unit]()
+
+  val releaseTaskSem = new Semaphore(0)
+  val taskMemoryManager = new AtomicReference[MemoryManager](null)
+  val taskStartedSem = new Semaphore(0)
+
+  def resetState(): Unit = {
+    releaseTaskSem.drainPermits()
+    taskStartedSem.drainPermits()
+    taskMemoryManager.set(null)
+  }
+}

private case class StubPartition(index: Int) extends Partition
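The new test coordinates three parties: the old task's completion listener blocks on releaseTaskSem; the main thread blocks on taskStartedSem before stopping the first context; and the driver plugin of the second context runs the threadLocalState callback, releasing the old task and then waiting on the old MemoryManager's monitor (which Spark notifies as task memory is released) before initialization continues. Stripped of Spark, the two-semaphore handshake reduces to the sketch below; the names are generic and only the pattern mirrors the test.

import java.util.concurrent.Semaphore

object SemaphoreHandshakeDemo {
  def main(args: Array[String]): Unit = {
    val taskStarted = new Semaphore(0)
    val releaseTask = new Semaphore(0)

    val task = new Thread(() => {
      taskStarted.release()   // signal: the task has begun
      releaseTask.acquire()   // block until the main thread allows completion
      println("task completed")
    })
    task.start()

    taskStarted.acquire()     // wait for the task to start, as the test does
    println("task is running; releasing it")
    releaseTask.release()
    task.join()
  }
}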
