@@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.streaming.state
20
20
import java .util .UUID
21
21
22
22
import scala .collection .immutable
23
+ import scala .concurrent .{ExecutionContext , Future }
24
+ import scala .concurrent .duration ._
23
25
import scala .util .Random
24
26
25
27
import org .apache .avro .AvroTypeException
26
28
import org .apache .hadoop .conf .Configuration
27
29
import org .scalatest .BeforeAndAfter
28
30
29
- import org .apache .spark .{SparkConf , SparkRuntimeException , SparkUnsupportedOperationException }
31
+ import org .apache .spark .{SparkConf , SparkException , SparkRuntimeException , SparkUnsupportedOperationException , TaskContext }
30
32
import org .apache .spark .io .CompressionCodec
31
33
import org .apache .spark .sql .LocalSparkSession .withSparkSession
32
34
import org .apache .spark .sql .SparkSession
@@ -40,7 +42,7 @@ import org.apache.spark.sql.types._
40
42
import org .apache .spark .tags .ExtendedSQLTest
41
43
import org .apache .spark .unsafe .Platform
42
44
import org .apache .spark .unsafe .types .UTF8String
43
- import org .apache .spark .util .Utils
45
+ import org .apache .spark .util .{ ThreadUtils , Utils }
44
46
45
47
@ ExtendedSQLTest
46
48
class RocksDBStateStoreSuite extends StateStoreSuiteBase [RocksDBStateStoreProvider ]
@@ -2321,6 +2323,164 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
2321
2323
}
2322
2324
}
2323
2325
2326
+ test(" Rocks DB task completion listener does not double unlock acquireThread" ) {
2327
+ // This test verifies that a thread that locks then unlocks the db and then
2328
+ // fires a completion listener (Thread 1) does not unlock the lock validly
2329
+ // acquired by another thread (Thread 2).
2330
+ //
2331
+ // Timeline of this test (* means thread is active):
2332
+ // STATE | MAIN | THREAD 1 | THREAD 2 |
2333
+ // ------| ---------------- | ---------------- | ---------------- |
2334
+ // 0. | wait for s3 | *load, commit | wait for s1 |
2335
+ // | | *signal s1 | |
2336
+ // ------| ---------------- | ---------------- | ---------------- |
2337
+ // 1. | | wait for s2 | *load, signal s2 |
2338
+ // ------| ---------------- | ---------------- | ---------------- |
2339
+ // 2. | | *task complete | wait for s4 |
2340
+ // | | *signal s3, END | |
2341
+ // ------| ---------------- | ---------------- | ---------------- |
2342
+ // 3. | *verify locked | | |
2343
+ // | *signal s4 | | |
2344
+ // ------| ---------------- | ---------------- | ---------------- |
2345
+ // 4. | wait for s5 | | *commit |
2346
+ // | | | *signal s5, END |
2347
+ // ------| ---------------- | ---------------- | ---------------- |
2348
+ // 5. | *close db, END | | |
2349
+ //
2350
+ // NOTE: state 4 and 5 are only for cleanup
2351
+
2352
+ // Create a custom ExecutionContext with 3 threads
2353
+ implicit val ec : ExecutionContext = ExecutionContext .fromExecutor(
2354
+ ThreadUtils .newDaemonFixedThreadPool(3 , " pool-thread-executor" ))
2355
+ val stateLock = new Object ()
2356
+ var state = 0
2357
+
2358
+ tryWithProviderResource(newStoreProvider()) { provider =>
2359
+ Future { // THREAD 1
2360
+ // Set thread 1's task context so that it is not a clone
2361
+ // of the main thread's taskContext, which will end if the
2362
+ // task is marked as complete
2363
+ val taskContext = TaskContext .empty()
2364
+ TaskContext .setTaskContext(taskContext)
2365
+
2366
+ stateLock.synchronized {
2367
+ // -------------------- STATE 0 --------------------
2368
+ // Simulate a task that loads and commits, db should be unlocked after
2369
+ val store = provider.getStore(0 )
2370
+ store.commit()
2371
+ // Signal that we have entered state 1
2372
+ state = 1
2373
+ stateLock.notifyAll()
2374
+
2375
+ // -------------------- STATE 2 --------------------
2376
+ // Wait until we have entered state 2 (thread 2 has loaded db and acquired lock)
2377
+ while (state != 2 ) {
2378
+ stateLock.wait()
2379
+ }
2380
+
2381
+ // thread 1's task context is marked as complete and signal
2382
+ // that we have entered state 3
2383
+ // At this point, thread 2 should still hold the DB lock.
2384
+ taskContext.markTaskCompleted(None )
2385
+ state = 3
2386
+ stateLock.notifyAll()
2387
+ }
2388
+ }
2389
+
2390
+ Future { // THREAD 2
2391
+ // Set thread 2's task context so that it is not a clone of thread 1's
2392
+ // so it won't be marked as complete
2393
+ val taskContext = TaskContext .empty()
2394
+ TaskContext .setTaskContext(taskContext)
2395
+
2396
+ stateLock.synchronized {
2397
+ // -------------------- STATE 1 --------------------
2398
+ // Wait until we have entered state 1 (thread 1 finished loading and committing)
2399
+ while (state != 1 ) {
2400
+ stateLock.wait()
2401
+ }
2402
+
2403
+ // Load the db and signal that we have entered state 2
2404
+ val store = provider.getStore(1 )
2405
+ assertAcquiredThreadIsCurrentThread(provider)
2406
+ state = 2
2407
+ stateLock.notifyAll()
2408
+
2409
+ // -------------------- STATE 4 --------------------
2410
+ // Wait until we have entered state 4 (thread 1 completed and
2411
+ // main thread confirmed that lock is held)
2412
+ while (state != 4 ) {
2413
+ stateLock.wait()
2414
+ }
2415
+
2416
+ // Ensure we still have the lock
2417
+ assertAcquiredThreadIsCurrentThread(provider)
2418
+
2419
+ // commit and signal that we have entered state 5
2420
+ store.commit()
2421
+ state = 5
2422
+ stateLock.notifyAll()
2423
+ }
2424
+ }
2425
+
2426
+ // MAIN THREAD
2427
+ stateLock.synchronized {
2428
+ // -------------------- STATE 3 --------------------
2429
+ // Wait until we have entered state 3 (thread 1 is complete)
2430
+ while (state != 3 ) {
2431
+ stateLock.wait()
2432
+ }
2433
+
2434
+ // Verify that the lock is being held
2435
+ val threadInfo = provider.stateMachine.getAcquiredThreadInfo
2436
+ assert(threadInfo.nonEmpty, s " acquiredThreadInfo was None when it should be Some " )
2437
+
2438
+ // Signal that we have entered state 4 (thread 2 can now release lock)
2439
+ state = 4
2440
+ stateLock.notifyAll()
2441
+
2442
+ // -------------------- STATE 5 --------------------
2443
+ // Wait until we have entered state 5 (thread 2 has released lock)
2444
+ // so that we can clean up
2445
+ while (state != 5 ) {
2446
+ stateLock.wait()
2447
+ }
2448
+ }
2449
+ }
2450
+ }
2451
+
2452
+ test(" RocksDB task completion listener correctly releases for failed task" ) {
2453
+ // This test verifies that a thread that locks the DB and then fails
2454
+ // can rely on the completion listener to release the lock.
2455
+
2456
+ // Create a custom ExecutionContext with 1 thread
2457
+ implicit val ec : ExecutionContext = ExecutionContext .fromExecutor(
2458
+ ThreadUtils .newDaemonSingleThreadExecutor(" single-thread-executor" ))
2459
+ val timeout = 5 .seconds
2460
+
2461
+ tryWithProviderResource(newStoreProvider()) { provider =>
2462
+ // New task that will load and then complete with failure
2463
+ val fut = Future {
2464
+ val taskContext = TaskContext .empty()
2465
+ TaskContext .setTaskContext(taskContext)
2466
+
2467
+ provider.getStore(0 )
2468
+ assertAcquiredThreadIsCurrentThread(provider)
2469
+
2470
+ // Task completion listener should unlock
2471
+ taskContext.markTaskCompleted(
2472
+ Some (new SparkException (" Task failure injection" )))
2473
+ }
2474
+
2475
+ ThreadUtils .awaitResult(fut, timeout)
2476
+
2477
+ // Assert that db is not locked
2478
+ val stamp = provider.stateMachine.currentValidStamp.get()
2479
+ assert(stamp == - 1 ,
2480
+ s " state machine stamp should be -1 (unlocked) but was $stamp" )
2481
+ }
2482
+ }
2483
+
2324
2484
override def newStoreProvider (): RocksDBStateStoreProvider = {
2325
2485
newStoreProvider(StateStoreId (newDir(), Random .nextInt(), 0 ))
2326
2486
}
@@ -2451,5 +2611,16 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
2451
2611
}
2452
2612
}
2453
2613
}
2614
+
2615
+ def assertAcquiredThreadIsCurrentThread (provider : RocksDBStateStoreProvider ): Unit = {
2616
+ val threadInfo = provider.stateMachine.getAcquiredThreadInfo
2617
+ assert(threadInfo.isDefined,
2618
+ " acquired thread info should not be null after load" )
2619
+ val threadId = threadInfo.get.threadRef.get.get.getId
2620
+ assert(
2621
+ threadId == Thread .currentThread().getId,
2622
+ s " acquired thread should be curent thread ${Thread .currentThread().getId} " +
2623
+ s " after load but was $threadId" )
2624
+ }
2454
2625
}
2455
2626
0 commit comments