@@ -248,6 +248,21 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
248
248
249
249
@ volatile
250
250
var streamThreadDeathCause : Throwable = null
251
+ // Set UncaughtExceptionHandler in `onQueryStarted` so that we can ensure catching fatal errors
252
+ // during query initialization.
253
+ val listener = new StreamingQueryListener {
254
+ override def onQueryStarted (event : QueryStartedEvent ): Unit = {
255
+ Thread .currentThread.setUncaughtExceptionHandler(new UncaughtExceptionHandler {
256
+ override def uncaughtException (t : Thread , e : Throwable ): Unit = {
257
+ streamThreadDeathCause = e
258
+ }
259
+ })
260
+ }
261
+
262
+ override def onQueryProgress (event : QueryProgressEvent ): Unit = {}
263
+ override def onQueryTerminated (event : QueryTerminatedEvent ): Unit = {}
264
+ }
265
+ sparkSession.streams.addListener(listener)
251
266
252
267
// If the test doesn't manually start the stream, we do it automatically at the beginning.
253
268
val startedManually =
@@ -364,12 +379,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
364
379
triggerClock = triggerClock)
365
380
.asInstanceOf [StreamingQueryWrapper ]
366
381
.streamingQuery
367
- currentStream.microBatchThread.setUncaughtExceptionHandler(
368
- new UncaughtExceptionHandler {
369
- override def uncaughtException (t : Thread , e : Throwable ): Unit = {
370
- streamThreadDeathCause = e
371
- }
372
- })
373
382
// Wait until the initialization finishes, because some tests need to use `logicalPlan`
374
383
// after starting the query.
375
384
currentStream.awaitInitialization(streamingTimeout.toMillis)
@@ -545,6 +554,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
545
554
case (key, Some (value)) => sparkSession.conf.set(key, value)
546
555
case (key, None ) => sparkSession.conf.unset(key)
547
556
}
557
+ sparkSession.streams.removeListener(listener)
548
558
}
549
559
}
550
560
0 commit comments