[SPARK-11761] Prevent the call to StreamingContext#stop() in the listener bus's thread #9741

Closed · wants to merge 8 commits
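This change makes StreamingContext#stop() fail fast when it is invoked from the listener bus's own thread. A StreamingListener callback runs on that thread, and stop() ultimately waits for that same thread to exit, so a stop() issued from inside a callback can never complete. The patch scopes the listener thread's event loop in a DynamicVariable[Boolean] flag (AsynchronousListenerBus.withinListenerThread) and has stop() throw a SparkException when the flag is set.

The guard pattern generalizes to any component whose shutdown path joins its own worker thread. Below is a minimal standalone sketch of that pattern; Worker, withinWorker, and shutdown are illustrative names, not Spark's API.

    import scala.util.DynamicVariable

    object Worker {
      // Default is false; reads as true only inside a withValue(true) scope.
      val withinWorker = new DynamicVariable[Boolean](false)
    }

    class Worker {
      private val thread = new Thread("worker") {
        override def run(): Unit = Worker.withinWorker.withValue(true) {
          // Event loop: anything called from here sees withinWorker.value == true.
        }
      }

      def start(): Unit = thread.start()

      def shutdown(): Unit = {
        // shutdown() joins the worker thread; invoking it *from* that thread
        // would block forever, so detect the situation and fail fast instead.
        if (Worker.withinWorker.value) {
          throw new IllegalStateException("Cannot call shutdown() from the worker thread")
        }
        thread.join()
      }
    }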
core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala
@@ -19,6 +19,7 @@ package org.apache.spark.util

 import java.util.concurrent._
 import java.util.concurrent.atomic.AtomicBoolean
+import scala.util.DynamicVariable
 
 import org.apache.spark.SparkContext

@@ -60,25 +61,27 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: String)
   private val listenerThread = new Thread(name) {
     setDaemon(true)
     override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) {
-      while (true) {
-        eventLock.acquire()
-        self.synchronized {
-          processingEvent = true
-        }
-        try {
-          val event = eventQueue.poll
-          if (event == null) {
-            // Get out of the while loop and shutdown the daemon thread
-            if (!stopped.get) {
-              throw new IllegalStateException("Polling `null` from eventQueue means" +
-                " the listener bus has been stopped. So `stopped` must be true")
-            }
-            return
-          }
-          postToAll(event)
-        } finally {
-          self.synchronized {
-            processingEvent = false
+      AsynchronousListenerBus.withinListenerThread.withValue(true) {
+        while (true) {
+          eventLock.acquire()
+          self.synchronized {
+            processingEvent = true
+          }
+          try {
+            val event = eventQueue.poll
+            if (event == null) {
+              // Get out of the while loop and shutdown the daemon thread
+              if (!stopped.get) {
+                throw new IllegalStateException("Polling `null` from eventQueue means" +
+                  " the listener bus has been stopped. So `stopped` must be true")
+              }
+              return
+            }
+            postToAll(event)
+          } finally {
+            self.synchronized {
+              processingEvent = false
+            }
           }
         }
       }
@@ -177,3 +180,10 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: String)
    */
   def onDropEvent(event: E): Unit
 }
+
+private[spark] object AsynchronousListenerBus {
+  /* Allows for Context to check whether stop() call is made within listener thread
+   */
+  val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false)
+}
+
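For context on the mechanism: scala.util.DynamicVariable wraps an InheritableThreadLocal, so .value returns the binding for the current thread, and withValue(v) { ... } rebinds it only for the dynamic extent of the block, restoring the previous binding afterwards even if the block throws. A quick self-contained illustration (Demo is just an example name):

    import scala.util.DynamicVariable

    object Demo extends App {
      val flag = new DynamicVariable[Boolean](false)

      flag.withValue(true) {
        assert(flag.value)  // bound to true inside the block
      }
      assert(!flag.value)   // restored to the default (false) outside it
    }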

streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._
 import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver}
 import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener}
 import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab}
-import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils}
+import org.apache.spark.util.{AsynchronousListenerBus, CallSite, ShutdownHookManager, ThreadUtils, Utils}
 
 /**
  * Main entry point for Spark Streaming functionality. It provides methods used to create
@@ -693,6 +693,10 @@ class StreamingContext private[streaming] (
    */
   def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = {
     var shutdownHookRefToRemove: AnyRef = null
+    if (AsynchronousListenerBus.withinListenerThread.value) {
+      throw new SparkException("Cannot stop StreamingContext within listener thread of" +
+        " AsynchronousListenerBus")
+    }
     synchronized {
       try {
         state match {
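For applications, the visible effect is that a listener which stops its own context now gets an immediate SparkException (which the listener bus logs) rather than hanging the shutdown path. A sketch of the kind of user code this guards against; the setup surrounding the listener is illustrative:

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Milliseconds, StreamingContext}
    import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted}

    object StopFromListenerDemo {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[2]").setAppName("stop-from-listener")
        val ssc = new StreamingContext(conf, Milliseconds(1000))

        ssc.addStreamingListener(new StreamingListener {
          override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
            // This callback runs on the listener bus thread, so with this patch the
            // call below throws SparkException instead of deadlocking stop().
            ssc.stop()
          }
        })
        // ... create an input stream, call ssc.start(), etc.
      }
    }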
streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, SynchronizedMap}
 import scala.concurrent.Future
 import scala.concurrent.ExecutionContext.Implicits.global
 
+import org.apache.spark.SparkException
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.streaming.receiver.Receiver
@@ -161,6 +162,14 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
     }
   }
 
+  test("don't call ssc.stop in listener") {
+    ssc = new StreamingContext("local[2]", "ssc", Milliseconds(1000))
+    val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver)
+    inputStream.foreachRDD(_.count)
+
+    startStreamingContextAndCallStop(ssc)
+  }
+
   test("onBatchCompleted with successful batch") {
     ssc = new StreamingContext("local[2]", "test", Milliseconds(1000))
     val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver)
@@ -207,6 +216,17 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
     assert(failureReasons(1).contains("This is another failed job"))
   }
 
+  private def startStreamingContextAndCallStop(_ssc: StreamingContext): Unit = {
+    val contextStoppingCollector = new StreamingContextStoppingCollector(_ssc)
+    _ssc.addStreamingListener(contextStoppingCollector)
+    val batchCounter = new BatchCounter(_ssc)
+    _ssc.start()
+    // Make sure running at least one batch
+    batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000)
+    _ssc.stop()
+    assert(contextStoppingCollector.sparkExSeen)
+  }
+
   private def startStreamingContextAndCollectFailureReasons(
       _ssc: StreamingContext, isFailed: Boolean = false): Map[Int, String] = {
     val failureReasonsCollector = new FailureReasonsCollector()
@@ -320,3 +340,17 @@ class FailureReasonsCollector extends StreamingListener {
     }
   }
 }
+/**
+ * A StreamingListener that calls StreamingContext.stop().
+ */
+class StreamingContextStoppingCollector(val ssc: StreamingContext) extends StreamingListener {
+  @volatile var sparkExSeen = false
+  override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
+    try {
+      ssc.stop()
+    } catch {
+      case se: SparkException =>
+        sparkExSeen = true
+    }
+  }
+}

Reviewer comment on the ssc.stop() call: "the listener bus will just log the exception. You can catch the exception here and use a field to store it. Then you can assert the exception in the test."
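The final revision above follows this suggestion: since the listener bus catches and logs any exception a listener throws, an assertion inside onBatchCompleted would never reach ScalaTest. The listener instead records the SparkException in the @volatile sparkExSeen field, and startStreamingContextAndCallStop asserts on that field after stop() returns.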