@@ -19,17 +19,26 @@ package org.apache.spark.sql.streaming
19
19
20
20
import java .util .TimeZone
21
21
22
+ import scala .collection .mutable
23
+ import scala .reflect .runtime .{universe => ru }
24
+
25
+ import org .apache .hadoop .conf .Configuration
26
+ import org .mockito .Mockito
27
+ import org .mockito .invocation .InvocationOnMock
28
+ import org .mockito .stubbing .Answer
22
29
import org .scalatest .BeforeAndAfterAll
30
+ import org .scalatest .PrivateMethodTester ._
23
31
24
32
import org .apache .spark .SparkException
25
33
import org .apache .spark .sql .AnalysisException
26
- import org .apache .spark .sql .catalyst .util .DateTimeUtils
34
+ import org .apache .spark .sql .catalyst .util ._
27
35
import org .apache .spark .sql .execution .SparkPlan
28
36
import org .apache .spark .sql .execution .streaming ._
29
- import org .apache .spark .sql .execution .streaming .state .StateStore
37
+ import org .apache .spark .sql .execution .streaming .state ._
30
38
import org .apache .spark .sql .expressions .scalalang .typed
31
39
import org .apache .spark .sql .functions ._
32
40
import org .apache .spark .sql .streaming .OutputMode ._
41
+ import org .apache .spark .sql .types ._
33
42
34
43
object FailureSinglton {
35
44
var firstTime = true
@@ -335,4 +344,67 @@ class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll {
335
344
CheckLastBatch ((90L , 1 ), (100L , 1 ), (105L , 1 ))
336
345
)
337
346
}
347
+
348
// Checks that `abort()` is invoked on the state store when `commit()` fails:
// a spied store whose `commit()` throws is injected into `StateStore.loadedProviders`,
// data is pushed through a streaming count aggregation, the query is expected to fail
// with a `SparkException`, and the `aborted` flag set by the spied `abort()` is asserted.
+ test(" abort StateStore in case of error" ) {
349
+ quietly {
350
+ val inputData = MemoryStream [Long ]
351
// Streaming aggregation under test: group the input longs by value and count each group.
+ val aggregated =
352
+ inputData.toDS()
353
+ .groupBy($" value" )
354
+ .agg(count(" *" ))
355
// Flipped to true by the stubbed `abort()` below; read by the final `AssertOnQuery`.
// NOTE(review): if `abort()` runs on a different thread than the final assertion, this
// plain captured `var` has no explicit visibility guarantee -- confirm, and consider
// `@volatile` if so.
+ var aborted = false
356
+ testStream(aggregated, Complete )(
357
+ // This whole `AssertOnQuery` is used to inject a mock state store
358
+ AssertOnQuery (execution => {
359
+ // (1) Use reflection to get `StateStore.loadedProviders`
// The member is looked up by name on the `StateStore` object and the value is cast to
// `mutable.HashMap[StateStoreId, StateStoreProvider]`; the cast is unchecked, so a change
// to the field's name or type would only surface at runtime.
// (Presumably reflection is needed because the member is not accessible -- TODO confirm.)
360
+ val loadedProviders = {
361
+ val field = ru.typeOf[StateStore .type ].decl(ru.TermName (" loadedProviders" )).asTerm
362
+ ru.runtimeMirror(StateStore .getClass.getClassLoader)
363
+ .reflect(StateStore )
364
+ .reflectField(field)
365
+ .get
366
+ .asInstanceOf [mutable.HashMap [StateStoreId , StateStoreProvider ]]
367
+ }
368
+ // (2) Make a storeId
// The checkpoint location comes from invoking the private `checkpointFile` method on the
// running `execution` via ScalaTest's `PrivateMethodTester`.
// (Presumably `0L` / `0` are the operator id and partition id -- verify against
// `StateStoreId`.)
369
+ val storeId = {
370
+ val checkpointLocation =
371
+ execution invokePrivate PrivateMethod [String ](' checkpointFile )(" state" )
372
+ StateStoreId (checkpointLocation, 0L , 0 )
373
+ }
374
+ // (3) Make `mockStore` and `mockProvider`
// Both are Mockito spies wrapping real objects, so unstubbed calls still reach the real
// implementation.
375
+ val (mockStore, mockProvider) = {
// Key schema: one non-nullable LongType field.
376
+ val keySchema = StructType (Seq (
377
+ StructField (" value" , LongType , false )))
// Value schema: two non-nullable LongType fields (the grouping value and its count).
378
+ val valueSchema = StructType (Seq (
379
+ StructField (" value" , LongType , false ), StructField (" count" , LongType , false )))
380
+ val storeConf = StateStoreConf .empty
381
+ val hadoopConf = new Configuration
// Tuple elements are evaluated left to right, so `StateStore.get` runs before the provider
// lookup. `loadedProviders.get(storeId).get` throws NoSuchElementException if no provider
// is registered for `storeId` at that point.
// (Presumably `StateStore.get` is what registers the provider -- TODO confirm.)
382
+ (Mockito .spy(
383
+ StateStore .get(storeId, keySchema, valueSchema, version = 0 , storeConf, hadoopConf)),
384
+ Mockito .spy(loadedProviders.get(storeId).get))
385
+ }
386
+ // (4) Set up `mockStore` and `mockProvider`
// `commit()` is stubbed to throw via `sys.error`, which is the injected failure.
387
+ Mockito .doAnswer(new Answer [Long ] {
388
+ override def answer (invocationOnMock : InvocationOnMock ): Long = {
389
+ sys.error(" injected error on commit()" )
390
+ }
391
+ }).when(mockStore).commit()
// `abort()` still delegates to the real implementation, then records that it was called.
392
+ Mockito .doAnswer(new Answer [Unit ] {
393
+ override def answer (invocationOnMock : InvocationOnMock ): Unit = {
394
+ invocationOnMock.callRealMethod()
395
+ // Mark the flag for later check
396
+ aborted = true
397
+ }
398
+ }).when(mockStore).abort()
// Only `getStore(version = 0)` is stubbed to hand back the spied store.
399
+ Mockito .doReturn(mockStore).when(mockProvider).getStore(version = 0 )
400
+ // (5) Inject `mockProvider`, which later on would inject `mockStore`
// `put` overwrites whatever entry the map currently holds for `storeId`.
401
+ loadedProviders.put(storeId, mockProvider)
// `AssertOnQuery` takes a condition; returning true marks the injection step as passed.
402
+ true
403
+ }), // End of AssertOnQuery, i.e. end of injecting `mockStore`
// Feed three rows, then expect the query to fail with a `SparkException`.
404
+ AddData (inputData, 1L , 2L , 3L ),
405
+ ExpectFailure [SparkException ](),
406
+ AssertOnQuery { _ => aborted } // Check that `mockStore.abort()` is called upon error
407
+ )
408
+ }
409
+ }
338
410
}
0 commit comments