Skip to content

rm acquirelock from rocksdb and move tests #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@ import java.util.Set
import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicLong}
import javax.annotation.concurrent.GuardedBy

import scala.collection.{mutable, Map}
import scala.jdk.CollectionConverters.ConcurrentMapHasAsScala
import scala.ref.WeakReference
import scala.util.Try

import org.apache.hadoop.conf.Configuration
Expand Down Expand Up @@ -148,8 +146,6 @@ class RocksDB(
private val byteArrayPair = new ByteArrayPair()
private val commitLatencyMs = new mutable.HashMap[String, Long]()

private val acquireLock = new Object

@volatile private var db: NativeRocksDB = _
@volatile private var changelogWriter: Option[StateStoreChangelogWriter] = None
private val enableChangelogCheckpointing: Boolean = conf.enableChangelogCheckpointing
Expand Down Expand Up @@ -185,24 +181,15 @@ class RocksDB(

// SPARK-46249 - Keep track of recorded metrics per version which can be used for querying later
// Updates and access to recordedMetrics are protected by the DB instance lock
@GuardedBy("acquireLock")
@volatile private var recordedMetrics: Option[RocksDBMetrics] = None

@GuardedBy("acquireLock")
@volatile private var acquiredThreadInfo: AcquiredThreadInfo = _

// This is accessed and updated only between load and commit
// which means it is implicitly guarded by acquireLock
@GuardedBy("acquireLock")
private val colFamilyNameToInfoMap = new ConcurrentHashMap[String, ColumnFamilyInfo]()

@GuardedBy("acquireLock")
private val colFamilyIdToNameMap = new ConcurrentHashMap[Short, String]()

@GuardedBy("acquireLock")
private val maxColumnFamilyId: AtomicInteger = new AtomicInteger(-1)

@GuardedBy("acquireLock")
private val shouldForceSnapshot: AtomicBoolean = new AtomicBoolean(false)

private def getColumnFamilyInfo(cfName: String): ColumnFamilyInfo = {
Expand Down Expand Up @@ -294,7 +281,6 @@ class RocksDB(
// This mapping should only be updated using the Task thread - at version load and commit time.
// If same mapping instance is updated from different threads,
// it will result in undefined behavior (and most likely incorrect mapping state).
@GuardedBy("acquireLock")
private val rocksDBFileMapping: RocksDBFileMapping = new RocksDBFileMapping()

// We send snapshots that needs to be uploaded by the maintenance thread to this queue
Expand Down Expand Up @@ -1215,7 +1201,6 @@ class RocksDB(

/** Release all resources */
def close(): Unit = {
// Acquire DB instance lock and release at the end to allow for synchronized access
try {
closeDB()

Expand Down Expand Up @@ -1366,11 +1351,6 @@ class RocksDB(
}
}

/**
 * Returns a defensive copy of the [[AcquiredThreadInfo]] recorded for the thread
 * that presumably holds the DB instance lock (the field is @GuardedBy("acquireLock")),
 * or None if no thread info has been recorded. Exposed for tests.
 */
private[state] def getAcquiredThreadInfo(): Option[AcquiredThreadInfo] =
acquireLock.synchronized {
// Option(...) maps a null field to None; copy() prevents callers from
// observing later updates made while the lock is held.
Option(acquiredThreadInfo).map(_.copy())
}

/** Upload the snapshot to DFS and remove it from snapshots pending */
private def uploadSnapshot(
snapshot: RocksDBSnapshot,
Expand Down Expand Up @@ -1927,21 +1907,6 @@ object RocksDBNativeHistogram {
}
}

// Records the identity of the thread that constructed this instance (by default
// the current thread and its TaskContext, if any). The thread is held through a
// WeakReference so this record does not keep a finished thread alive.
case class AcquiredThreadInfo(
threadRef: WeakReference[Thread] = new WeakReference[Thread](Thread.currentThread()),
tc: TaskContext = TaskContext.get()) {
// Renders "[ThreadId: Some(<id>), task: partition p.a in stage s.sa, TID t]"
// when captured on a task thread, or "[ThreadId: Some(<id>)]" otherwise.
// threadRef.get yields None (printed as such) if the thread was collected.
override def toString(): String = {
val taskStr = if (tc != null) {
val taskDetails =
s"partition ${tc.partitionId()}.${tc.attemptNumber()} in stage " +
s"${tc.stageId()}.${tc.stageAttemptNumber()}, TID ${tc.taskAttemptId()}"
s", task: $taskDetails"
} else ""

s"[ThreadId: ${threadRef.get.map(_.getId)}$taskStr]"
}
}

/**
* A helper class to manage the lineage information when checkpoint unique id is enabled.
* "lineage" is an array of LineageItem (version, uniqueId) pair.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,8 @@ private[sql] class RocksDBStateStoreProvider

override def stateStoreId: StateStoreId = stateStoreId_

private lazy val stateMachine: RocksDBStateStoreProviderStateMachine =
// Exposed for testing
private[state] lazy val stateMachine: RocksDBStateStoreProviderStateMachine =
new RocksDBStateStoreProviderStateMachine(stateStoreId, RocksDBConf(storeConf))

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong
import javax.annotation.concurrent.GuardedBy

import scala.ref.WeakReference

import org.apache.spark.TaskContext
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.sql.errors.QueryExecutionErrors

Expand Down Expand Up @@ -70,11 +73,15 @@ class RocksDBStateStoreProviderStateMachine(
private var state: STATE = RELEASED
@GuardedBy("instanceLock")
private var acquiredThreadInfo: AcquiredThreadInfo = _
// Exposed for testing
/**
 * Returns a snapshot (copy) of the thread info recorded under `instanceLock`,
 * or None when nothing has been recorded yet. Exposed for testing.
 */
private[spark] def getAcquiredThreadInfo: Option[AcquiredThreadInfo] = {
  instanceLock.synchronized {
    acquiredThreadInfo match {
      case null => None
      // copy() so callers never observe later in-place updates
      case info => Some(info.copy())
    }
  }
}

// Can be read without holding any locks, but should only be updated when
// instanceLock is held.
// -1 indicates that the store is not locked.
private[sql] val currentValidStamp = new AtomicLong(-1L)
private[state] val currentValidStamp = new AtomicLong(-1L)
@GuardedBy("instanceLock")
private var lastValidStamp: Long = 0L

Expand Down Expand Up @@ -187,3 +194,18 @@ class RocksDBStateStoreProviderStateMachine(
validateAndTransitionState(CLOSE)
}
}

/**
 * Snapshot of the thread (and Spark task, when one is running) captured at
 * construction time; defaults record the constructing thread and its
 * TaskContext. The thread is held via WeakReference so a completed thread
 * can still be garbage collected.
 */
case class AcquiredThreadInfo(
    threadRef: WeakReference[Thread] = new WeakReference[Thread](Thread.currentThread()),
    tc: TaskContext = TaskContext.get()) {
  override def toString(): String = {
    // Append task details only when a TaskContext was captured.
    val taskStr =
      if (tc == null) {
        ""
      } else {
        s", task: partition ${tc.partitionId()}.${tc.attemptNumber()} in stage " +
          s"${tc.stageId()}.${tc.stageAttemptNumber()}, TID ${tc.taskAttemptId()}"
      }
    s"[ThreadId: ${threadRef.get.map(_.getId)}$taskStr]"
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.streaming.state
import java.util.UUID

import scala.collection.immutable
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
import scala.util.Random

import org.apache.avro.AvroTypeException
import org.apache.hadoop.conf.Configuration
import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SparkRuntimeException, SparkUnsupportedOperationException}
import org.apache.spark.{SparkConf, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, TaskContext}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.LocalSparkSession.withSparkSession
import org.apache.spark.sql.SparkSession
Expand All @@ -40,7 +42,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.tags.ExtendedSQLTest
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
import org.apache.spark.util.{ThreadUtils, Utils}

@ExtendedSQLTest
class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider]
Expand Down Expand Up @@ -2321,6 +2323,164 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
}
}

test("Rocks DB task completion listener does not double unlock acquireThread") {
// This test verifies that a thread that locks then unlocks the db and then
// fires a completion listener (Thread 1) does not unlock the lock validly
// acquired by another thread (Thread 2).
//
// Timeline of this test (* means thread is active):
// STATE | MAIN | THREAD 1 | THREAD 2 |
// ------| ---------------- | ---------------- | ---------------- |
// 0. | wait for s3 | *load, commit | wait for s1 |
// | | *signal s1 | |
// ------| ---------------- | ---------------- | ---------------- |
// 1. | | wait for s2 | *load, signal s2 |
// ------| ---------------- | ---------------- | ---------------- |
// 2. | | *task complete | wait for s4 |
// | | *signal s3, END | |
// ------| ---------------- | ---------------- | ---------------- |
// 3. | *verify locked | | |
// | *signal s4 | | |
// ------| ---------------- | ---------------- | ---------------- |
// 4. | wait for s5 | | *commit |
// | | | *signal s5, END |
// ------| ---------------- | ---------------- | ---------------- |
// 5. | *close db, END | | |
//
// NOTE: state 4 and 5 are only for cleanup

// Create a custom ExecutionContext with 3 threads
implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(
ThreadUtils.newDaemonFixedThreadPool(3, "pool-thread-executor"))
val stateLock = new Object()
// 'state' is only read and written inside stateLock.synchronized blocks.
var state = 0

tryWithProviderResource(newStoreProvider()) { provider =>
// NOTE(review): neither Future below is awaited; if a future throws before
// signaling its state transition, the waiting threads will block until the
// suite timeout rather than fail fast — TODO confirm this is acceptable.
Future { // THREAD 1
// Set thread 1's task context so that it is not a clone
// of the main thread's taskContext, which will end if the
// task is marked as complete
val taskContext = TaskContext.empty()
TaskContext.setTaskContext(taskContext)

stateLock.synchronized {
// -------------------- STATE 0 --------------------
// Simulate a task that loads and commits, db should be unlocked after
val store = provider.getStore(0)
store.commit()
// Signal that we have entered state 1
state = 1
stateLock.notifyAll()

// -------------------- STATE 2 --------------------
// Wait until we have entered state 2 (thread 2 has loaded db and acquired lock)
while (state != 2) {
stateLock.wait()
}

// thread 1's task context is marked as complete and signal
// that we have entered state 3
// At this point, thread 2 should still hold the DB lock.
taskContext.markTaskCompleted(None)
state = 3
stateLock.notifyAll()
}
}

Future { // THREAD 2
// Set thread 2's task context so that it is not a clone of thread 1's
// so it won't be marked as complete
val taskContext = TaskContext.empty()
TaskContext.setTaskContext(taskContext)

stateLock.synchronized {
// -------------------- STATE 1 --------------------
// Wait until we have entered state 1 (thread 1 finished loading and committing)
while (state != 1) {
stateLock.wait()
}

// Load the db and signal that we have entered state 2
val store = provider.getStore(1)
assertAcquiredThreadIsCurrentThread(provider)
state = 2
stateLock.notifyAll()

// -------------------- STATE 4 --------------------
// Wait until we have entered state 4 (thread 1 completed and
// main thread confirmed that lock is held)
while (state != 4) {
stateLock.wait()
}

// Ensure we still have the lock
assertAcquiredThreadIsCurrentThread(provider)

// commit and signal that we have entered state 5
store.commit()
state = 5
stateLock.notifyAll()
}
}

// MAIN THREAD
stateLock.synchronized {
// -------------------- STATE 3 --------------------
// Wait until we have entered state 3 (thread 1 is complete)
while (state != 3) {
stateLock.wait()
}

// Verify that the lock is being held
// (this is the core assertion: thread 1's completion listener must not
// have released the lock acquired by thread 2)
val threadInfo = provider.stateMachine.getAcquiredThreadInfo
assert(threadInfo.nonEmpty, s"acquiredThreadInfo was None when it should be Some")

// Signal that we have entered state 4 (thread 2 can now release lock)
state = 4
stateLock.notifyAll()

// -------------------- STATE 5 --------------------
// Wait until we have entered state 5 (thread 2 has released lock)
// so that we can clean up
while (state != 5) {
stateLock.wait()
}
}
}
}

test("RocksDB task completion listener correctly releases for failed task") {
// This test verifies that a thread that locks the DB and then fails
// can rely on the completion listener to release the lock.

// Create a custom ExecutionContext with 1 thread
implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(
ThreadUtils.newDaemonSingleThreadExecutor("single-thread-executor"))
val timeout = 5.seconds

tryWithProviderResource(newStoreProvider()) { provider =>
// New task that will load and then complete with failure
val fut = Future {
val taskContext = TaskContext.empty()
TaskContext.setTaskContext(taskContext)

// getStore acquires the instance lock for this thread
// (verified by the assertion on the next line)
provider.getStore(0)
assertAcquiredThreadIsCurrentThread(provider)

// Task completion listener should unlock
taskContext.markTaskCompleted(
Some(new SparkException("Task failure injection")))
}

// Wait for the background task to finish (and surface any failure it threw)
// before checking the lock state
ThreadUtils.awaitResult(fut, timeout)

// Assert that db is not locked
// (-1 is the state machine's sentinel for "not locked")
val stamp = provider.stateMachine.currentValidStamp.get()
assert(stamp == -1,
s"state machine stamp should be -1 (unlocked) but was $stamp")
}
}

/** Creates a store provider rooted at a fresh temporary directory. */
override def newStoreProvider(): RocksDBStateStoreProvider = {
  val storeId = StateStoreId(newDir(), Random.nextInt(), 0)
  newStoreProvider(storeId)
}
Expand Down Expand Up @@ -2451,5 +2611,16 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
}
}
}

/**
 * Asserts that the provider's state machine records the calling thread as the
 * current lock holder. Fails with a descriptive message if no thread info is
 * recorded, if the weakly-referenced thread has been garbage collected, or if
 * the recorded thread differs from the current thread.
 */
def assertAcquiredThreadIsCurrentThread(provider: RocksDBStateStoreProvider): Unit = {
  val threadInfo = provider.stateMachine.getAcquiredThreadInfo
  assert(threadInfo.isDefined,
    "acquired thread info should not be null after load")
  // threadRef is a WeakReference; resolve it explicitly instead of calling
  // .get.get, which would throw a bare NoSuchElementException if the thread
  // had been collected.
  val acquiredThread = threadInfo.get.threadRef.get
  assert(acquiredThread.isDefined,
    "acquired thread was garbage collected; cannot verify lock owner")
  val threadId = acquiredThread.get.getId
  assert(
    threadId == Thread.currentThread().getId,
    s"acquired thread should be current thread ${Thread.currentThread().getId} " +
      s"after load but was $threadId")
}
}

Loading
Loading