Skip to content

Commit 253d70b

Browse files
authored
rm acquirelock from rocksdb and move tests (#26)
* rm acquirelock from rocksdb and move tests * fix tests
1 parent 25f623b commit 253d70b

File tree

5 files changed

+199
-228
lines changed

5 files changed

+199
-228
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,9 @@ import java.util.Set
2323
import java.util.UUID
2424
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
2525
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicLong}
26-
import javax.annotation.concurrent.GuardedBy
2726

2827
import scala.collection.{mutable, Map}
2928
import scala.jdk.CollectionConverters.ConcurrentMapHasAsScala
30-
import scala.ref.WeakReference
3129
import scala.util.Try
3230

3331
import org.apache.hadoop.conf.Configuration
@@ -148,8 +146,6 @@ class RocksDB(
148146
private val byteArrayPair = new ByteArrayPair()
149147
private val commitLatencyMs = new mutable.HashMap[String, Long]()
150148

151-
private val acquireLock = new Object
152-
153149
@volatile private var db: NativeRocksDB = _
154150
@volatile private var changelogWriter: Option[StateStoreChangelogWriter] = None
155151
private val enableChangelogCheckpointing: Boolean = conf.enableChangelogCheckpointing
@@ -185,24 +181,15 @@ class RocksDB(
185181

186182
// SPARK-46249 - Keep track of recorded metrics per version which can be used for querying later
187183
// Updates and access to recordedMetrics are protected by the DB instance lock
188-
@GuardedBy("acquireLock")
189184
@volatile private var recordedMetrics: Option[RocksDBMetrics] = None
190185

191-
@GuardedBy("acquireLock")
192-
@volatile private var acquiredThreadInfo: AcquiredThreadInfo = _
193-
194186
// This is accessed and updated only between load and commit
195-
// which means it is implicitly guarded by acquireLock
196-
@GuardedBy("acquireLock")
197187
private val colFamilyNameToInfoMap = new ConcurrentHashMap[String, ColumnFamilyInfo]()
198188

199-
@GuardedBy("acquireLock")
200189
private val colFamilyIdToNameMap = new ConcurrentHashMap[Short, String]()
201190

202-
@GuardedBy("acquireLock")
203191
private val maxColumnFamilyId: AtomicInteger = new AtomicInteger(-1)
204192

205-
@GuardedBy("acquireLock")
206193
private val shouldForceSnapshot: AtomicBoolean = new AtomicBoolean(false)
207194

208195
private def getColumnFamilyInfo(cfName: String): ColumnFamilyInfo = {
@@ -294,7 +281,6 @@ class RocksDB(
294281
// This mapping should only be updated using the Task thread - at version load and commit time.
295282
// If same mapping instance is updated from different threads,
296283
// it will result in undefined behavior (and most likely incorrect mapping state).
297-
@GuardedBy("acquireLock")
298284
private val rocksDBFileMapping: RocksDBFileMapping = new RocksDBFileMapping()
299285

300286
// We send snapshots that needs to be uploaded by the maintenance thread to this queue
@@ -1215,7 +1201,6 @@ class RocksDB(
12151201

12161202
/** Release all resources */
12171203
def close(): Unit = {
1218-
// Acquire DB instance lock and release at the end to allow for synchronized access
12191204
try {
12201205
closeDB()
12211206

@@ -1366,11 +1351,6 @@ class RocksDB(
13661351
}
13671352
}
13681353

1369-
private[state] def getAcquiredThreadInfo(): Option[AcquiredThreadInfo] =
1370-
acquireLock.synchronized {
1371-
Option(acquiredThreadInfo).map(_.copy())
1372-
}
1373-
13741354
/** Upload the snapshot to DFS and remove it from snapshots pending */
13751355
private def uploadSnapshot(
13761356
snapshot: RocksDBSnapshot,
@@ -1927,21 +1907,6 @@ object RocksDBNativeHistogram {
19271907
}
19281908
}
19291909

1930-
case class AcquiredThreadInfo(
1931-
threadRef: WeakReference[Thread] = new WeakReference[Thread](Thread.currentThread()),
1932-
tc: TaskContext = TaskContext.get()) {
1933-
override def toString(): String = {
1934-
val taskStr = if (tc != null) {
1935-
val taskDetails =
1936-
s"partition ${tc.partitionId()}.${tc.attemptNumber()} in stage " +
1937-
s"${tc.stageId()}.${tc.stageAttemptNumber()}, TID ${tc.taskAttemptId()}"
1938-
s", task: $taskDetails"
1939-
} else ""
1940-
1941-
s"[ThreadId: ${threadRef.get.map(_.getId)}$taskStr]"
1942-
}
1943-
}
1944-
19451910
/**
19461911
* A helper class to manage the lineage information when checkpoint unique id is enabled.
19471912
* "lineage" is an array of LineageItem (version, uniqueId) pair.

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,8 @@ private[sql] class RocksDBStateStoreProvider
610610

611611
override def stateStoreId: StateStoreId = stateStoreId_
612612

613-
private lazy val stateMachine: RocksDBStateStoreProviderStateMachine =
613+
// Exposed for testing
614+
private[state] lazy val stateMachine: RocksDBStateStoreProviderStateMachine =
614615
new RocksDBStateStoreProviderStateMachine(stateStoreId, RocksDBConf(storeConf))
615616

616617
/**

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProviderStateMachine.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ import java.util.concurrent.TimeUnit
2121
import java.util.concurrent.atomic.AtomicLong
2222
import javax.annotation.concurrent.GuardedBy
2323

24+
import scala.ref.WeakReference
25+
26+
import org.apache.spark.TaskContext
2427
import org.apache.spark.internal.{Logging, LogKeys, MDC}
2528
import org.apache.spark.sql.errors.QueryExecutionErrors
2629

@@ -70,11 +73,15 @@ class RocksDBStateStoreProviderStateMachine(
7073
private var state: STATE = RELEASED
7174
@GuardedBy("instanceLock")
7275
private var acquiredThreadInfo: AcquiredThreadInfo = _
76+
// Exposed for testing
77+
private[spark] def getAcquiredThreadInfo: Option[AcquiredThreadInfo] = instanceLock.synchronized {
78+
Option(acquiredThreadInfo).map(_.copy())
79+
}
7380

7481
// Can be read without holding any locks, but should only be updated when
7582
// instanceLock is held.
7683
// -1 indicates that the store is not locked.
77-
private[sql] val currentValidStamp = new AtomicLong(-1L)
84+
private[state] val currentValidStamp = new AtomicLong(-1L)
7885
@GuardedBy("instanceLock")
7986
private var lastValidStamp: Long = 0L
8087

@@ -187,3 +194,18 @@ class RocksDBStateStoreProviderStateMachine(
187194
validateAndTransitionState(CLOSE)
188195
}
189196
}
197+
198+
case class AcquiredThreadInfo(
199+
threadRef: WeakReference[Thread] = new WeakReference[Thread](Thread.currentThread()),
200+
tc: TaskContext = TaskContext.get()) {
201+
override def toString(): String = {
202+
val taskStr = if (tc != null) {
203+
val taskDetails =
204+
s"partition ${tc.partitionId()}.${tc.attemptNumber()} in stage " +
205+
s"${tc.stageId()}.${tc.stageAttemptNumber()}, TID ${tc.taskAttemptId()}"
206+
s", task: $taskDetails"
207+
} else ""
208+
209+
s"[ThreadId: ${threadRef.get.map(_.getId)}$taskStr]"
210+
}
211+
}

sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala

Lines changed: 173 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.streaming.state
2020
import java.util.UUID
2121

2222
import scala.collection.immutable
23+
import scala.concurrent.{ExecutionContext, Future}
24+
import scala.concurrent.duration._
2325
import scala.util.Random
2426

2527
import org.apache.avro.AvroTypeException
2628
import org.apache.hadoop.conf.Configuration
2729
import org.scalatest.BeforeAndAfter
2830

29-
import org.apache.spark.{SparkConf, SparkRuntimeException, SparkUnsupportedOperationException}
31+
import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, TaskContext}
3032
import org.apache.spark.io.CompressionCodec
3133
import org.apache.spark.sql.LocalSparkSession.withSparkSession
3234
import org.apache.spark.sql.SparkSession
@@ -40,7 +42,7 @@ import org.apache.spark.sql.types._
4042
import org.apache.spark.tags.ExtendedSQLTest
4143
import org.apache.spark.unsafe.Platform
4244
import org.apache.spark.unsafe.types.UTF8String
43-
import org.apache.spark.util.Utils
45+
import org.apache.spark.util.{ThreadUtils, Utils}
4446

4547
@ExtendedSQLTest
4648
class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider]
@@ -2321,6 +2323,164 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
23212323
}
23222324
}
23232325

2326+
test("Rocks DB task completion listener does not double unlock acquireThread") {
2327+
// This test verifies that a thread that locks then unlocks the db and then
2328+
// fires a completion listener (Thread 1) does not unlock the lock validly
2329+
// acquired by another thread (Thread 2).
2330+
//
2331+
// Timeline of this test (* means thread is active):
2332+
// STATE | MAIN | THREAD 1 | THREAD 2 |
2333+
// ------| ---------------- | ---------------- | ---------------- |
2334+
// 0. | wait for s3 | *load, commit | wait for s1 |
2335+
// | | *signal s1 | |
2336+
// ------| ---------------- | ---------------- | ---------------- |
2337+
// 1. | | wait for s2 | *load, signal s2 |
2338+
// ------| ---------------- | ---------------- | ---------------- |
2339+
// 2. | | *task complete | wait for s4 |
2340+
// | | *signal s3, END | |
2341+
// ------| ---------------- | ---------------- | ---------------- |
2342+
// 3. | *verify locked | | |
2343+
// | *signal s4 | | |
2344+
// ------| ---------------- | ---------------- | ---------------- |
2345+
// 4. | wait for s5 | | *commit |
2346+
// | | | *signal s5, END |
2347+
// ------| ---------------- | ---------------- | ---------------- |
2348+
// 5. | *close db, END | | |
2349+
//
2350+
// NOTE: state 4 and 5 are only for cleanup
2351+
2352+
// Create a custom ExecutionContext with 3 threads
2353+
implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(
2354+
ThreadUtils.newDaemonFixedThreadPool(3, "pool-thread-executor"))
2355+
val stateLock = new Object()
2356+
var state = 0
2357+
2358+
tryWithProviderResource(newStoreProvider()) { provider =>
2359+
Future { // THREAD 1
2360+
// Set thread 1's task context so that it is not a clone
2361+
// of the main thread's taskContext, which will end if the
2362+
// task is marked as complete
2363+
val taskContext = TaskContext.empty()
2364+
TaskContext.setTaskContext(taskContext)
2365+
2366+
stateLock.synchronized {
2367+
// -------------------- STATE 0 --------------------
2368+
// Simulate a task that loads and commits, db should be unlocked after
2369+
val store = provider.getStore(0)
2370+
store.commit()
2371+
// Signal that we have entered state 1
2372+
state = 1
2373+
stateLock.notifyAll()
2374+
2375+
// -------------------- STATE 2 --------------------
2376+
// Wait until we have entered state 2 (thread 2 has loaded db and acquired lock)
2377+
while (state != 2) {
2378+
stateLock.wait()
2379+
}
2380+
2381+
// thread 1's task context is marked as complete and signal
2382+
// that we have entered state 3
2383+
// At this point, thread 2 should still hold the DB lock.
2384+
taskContext.markTaskCompleted(None)
2385+
state = 3
2386+
stateLock.notifyAll()
2387+
}
2388+
}
2389+
2390+
Future { // THREAD 2
2391+
// Set thread 2's task context so that it is not a clone of thread 1's
2392+
// so it won't be marked as complete
2393+
val taskContext = TaskContext.empty()
2394+
TaskContext.setTaskContext(taskContext)
2395+
2396+
stateLock.synchronized {
2397+
// -------------------- STATE 1 --------------------
2398+
// Wait until we have entered state 1 (thread 1 finished loading and committing)
2399+
while (state != 1) {
2400+
stateLock.wait()
2401+
}
2402+
2403+
// Load the db and signal that we have entered state 2
2404+
val store = provider.getStore(1)
2405+
assertAcquiredThreadIsCurrentThread(provider)
2406+
state = 2
2407+
stateLock.notifyAll()
2408+
2409+
// -------------------- STATE 4 --------------------
2410+
// Wait until we have entered state 4 (thread 1 completed and
2411+
// main thread confirmed that lock is held)
2412+
while (state != 4) {
2413+
stateLock.wait()
2414+
}
2415+
2416+
// Ensure we still have the lock
2417+
assertAcquiredThreadIsCurrentThread(provider)
2418+
2419+
// commit and signal that we have entered state 5
2420+
store.commit()
2421+
state = 5
2422+
stateLock.notifyAll()
2423+
}
2424+
}
2425+
2426+
// MAIN THREAD
2427+
stateLock.synchronized {
2428+
// -------------------- STATE 3 --------------------
2429+
// Wait until we have entered state 3 (thread 1 is complete)
2430+
while (state != 3) {
2431+
stateLock.wait()
2432+
}
2433+
2434+
// Verify that the lock is being held
2435+
val threadInfo = provider.stateMachine.getAcquiredThreadInfo
2436+
assert(threadInfo.nonEmpty, s"acquiredThreadInfo was None when it should be Some")
2437+
2438+
// Signal that we have entered state 4 (thread 2 can now release lock)
2439+
state = 4
2440+
stateLock.notifyAll()
2441+
2442+
// -------------------- STATE 5 --------------------
2443+
// Wait until we have entered state 5 (thread 2 has released lock)
2444+
// so that we can clean up
2445+
while (state != 5) {
2446+
stateLock.wait()
2447+
}
2448+
}
2449+
}
2450+
}
2451+
2452+
test("RocksDB task completion listener correctly releases for failed task") {
2453+
// This test verifies that a thread that locks the DB and then fails
2454+
// can rely on the completion listener to release the lock.
2455+
2456+
// Create a custom ExecutionContext with 1 thread
2457+
implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(
2458+
ThreadUtils.newDaemonSingleThreadExecutor("single-thread-executor"))
2459+
val timeout = 5.seconds
2460+
2461+
tryWithProviderResource(newStoreProvider()) { provider =>
2462+
// New task that will load and then complete with failure
2463+
val fut = Future {
2464+
val taskContext = TaskContext.empty()
2465+
TaskContext.setTaskContext(taskContext)
2466+
2467+
provider.getStore(0)
2468+
assertAcquiredThreadIsCurrentThread(provider)
2469+
2470+
// Task completion listener should unlock
2471+
taskContext.markTaskCompleted(
2472+
Some(new SparkException("Task failure injection")))
2473+
}
2474+
2475+
ThreadUtils.awaitResult(fut, timeout)
2476+
2477+
// Assert that db is not locked
2478+
val stamp = provider.stateMachine.currentValidStamp.get()
2479+
assert(stamp == -1,
2480+
s"state machine stamp should be -1 (unlocked) but was $stamp")
2481+
}
2482+
}
2483+
23242484
override def newStoreProvider(): RocksDBStateStoreProvider = {
23252485
newStoreProvider(StateStoreId(newDir(), Random.nextInt(), 0))
23262486
}
@@ -2451,5 +2611,16 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
24512611
}
24522612
}
24532613
}
2614+
2615+
def assertAcquiredThreadIsCurrentThread(provider: RocksDBStateStoreProvider): Unit = {
2616+
val threadInfo = provider.stateMachine.getAcquiredThreadInfo
2617+
assert(threadInfo.isDefined,
2618+
"acquired thread info should not be null after load")
2619+
val threadId = threadInfo.get.threadRef.get.get.getId
2620+
assert(
2621+
threadId == Thread.currentThread().getId,
2622+
s"acquired thread should be curent thread ${Thread.currentThread().getId} " +
2623+
s"after load but was $threadId")
2624+
}
24542625
}
24552626

0 commit comments

Comments
 (0)