Skip to content

Commit 3ef183a

Browse files
zsxwingsrowen
authored andcommitted
[SPARK-19113][SS][TESTS] Set UncaughtExceptionHandler in onQueryStarted to ensure catching fatal errors during query initialization
## What changes were proposed in this pull request? StreamTest sets `UncaughtExceptionHandler` after starting the query now. It may not be able to catch fatal errors during query initialization. This PR uses `onQueryStarted` callback to fix it. ## How was this patch tested? Jenkins Author: Shixiong Zhu <shixiong@databricks.com> Closes #16492 from zsxwing/SPARK-19113.
1 parent a2c6adc commit 3ef183a

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ class StreamSuite extends StreamTest {
238238
}
239239
}
240240

241-
testQuietly("fatal errors from a source should be sent to the user") {
241+
testQuietly("handle fatal errors thrown from the stream thread") {
242242
for (e <- Seq(
243243
new VirtualMachineError {},
244244
new ThreadDeath,
@@ -259,8 +259,11 @@ class StreamSuite extends StreamTest {
259259
override def stop(): Unit = {}
260260
}
261261
val df = Dataset[Int](sqlContext.sparkSession, StreamingExecutionRelation(source))
262-
// These error are fatal errors and should be ignored in `testStream` to not fail the test.
263262
testStream(df)(
263+
// `ExpectFailure(isFatalError = true)` verifies two things:
264+
// - Fatal errors can be propagated to `StreamingQuery.exception` and
265+
// `StreamingQuery.awaitTermination` like non fatal errors.
266+
// - Fatal errors can be caught by UncaughtExceptionHandler.
264267
ExpectFailure(isFatalError = true)(ClassTag(e.getClass))
265268
)
266269
}

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
235235
*/
236236
def testStream(
237237
_stream: Dataset[_],
238-
outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = {
238+
outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized {
239+
// `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently
240+
// because this method assumes there is only one active query in its `StreamingQueryListener`
241+
// and it may not work correctly when multiple `testStream`s run concurrently.
239242

240243
val stream = _stream.toDF()
241244
val sparkSession = stream.sparkSession // use the session in DF, not the default session
@@ -248,6 +251,22 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
248251

249252
@volatile
250253
var streamThreadDeathCause: Throwable = null
254+
// Set UncaughtExceptionHandler in `onQueryStarted` so that we can ensure catching fatal errors
255+
// during query initialization.
256+
val listener = new StreamingQueryListener {
257+
override def onQueryStarted(event: QueryStartedEvent): Unit = {
258+
// Note: this assumes there is only one query active in the `testStream` method.
259+
Thread.currentThread.setUncaughtExceptionHandler(new UncaughtExceptionHandler {
260+
override def uncaughtException(t: Thread, e: Throwable): Unit = {
261+
streamThreadDeathCause = e
262+
}
263+
})
264+
}
265+
266+
override def onQueryProgress(event: QueryProgressEvent): Unit = {}
267+
override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {}
268+
}
269+
sparkSession.streams.addListener(listener)
251270

252271
// If the test doesn't manually start the stream, we do it automatically at the beginning.
253272
val startedManually =
@@ -364,12 +383,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
364383
triggerClock = triggerClock)
365384
.asInstanceOf[StreamingQueryWrapper]
366385
.streamingQuery
367-
currentStream.microBatchThread.setUncaughtExceptionHandler(
368-
new UncaughtExceptionHandler {
369-
override def uncaughtException(t: Thread, e: Throwable): Unit = {
370-
streamThreadDeathCause = e
371-
}
372-
})
373386
// Wait until the initialization finishes, because some tests need to use `logicalPlan`
374387
// after starting the query.
375388
currentStream.awaitInitialization(streamingTimeout.toMillis)
@@ -545,6 +558,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
545558
case (key, Some(value)) => sparkSession.conf.set(key, value)
546559
case (key, None) => sparkSession.conf.unset(key)
547560
}
561+
sparkSession.streams.removeListener(listener)
548562
}
549563
}
550564

0 commit comments

Comments
 (0)