Skip to content

Provider state machine #25

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Apr 15, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ abstract class StatePartitionReaderBase(
stateStoreColFamilySchema.keyStateEncoderSpec.get,
useMultipleValuesPerKey = useMultipleValuesPerKey,
isInternal = isInternal)
store.commit()
}
provider
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ private[sql] class RocksDBStateStoreProvider
with SupportsFineGrainedReplay {
import RocksDBStateStoreProvider._

class RocksDBStateStore(lastVersion: Long) extends StateStore {
class RocksDBStateStore(lastVersion: Long, val stamp: Long) extends StateStore {
/** Trait and classes representing the internal state of the store */
trait STATE
case object UPDATING extends STATE
Expand All @@ -58,6 +58,10 @@ private[sql] class RocksDBStateStoreProvider

@volatile private var state: STATE = UPDATING

override def getReadStamp: Long = {
stamp
}

/**
* Validates the expected state, throws exception if state is not as expected.
* Returns the current state
Expand All @@ -81,6 +85,7 @@ private[sql] class RocksDBStateStoreProvider
private def validateAndTransitionState(transition: TRANSITION): Unit = {
val newState = transition match {
case UPDATE =>
stateMachine.verifyStamp(stamp)
state match {
case UPDATING => UPDATING
case COMMITTED => throw StateStoreErrors.stateStoreOperationOutOfOrder(
Expand All @@ -90,14 +95,18 @@ private[sql] class RocksDBStateStoreProvider
}
case ABORT =>
state match {
case UPDATING => ABORTED
case UPDATING =>
stateMachine.verifyStamp(stamp)
ABORTED
case COMMITTED => throw StateStoreErrors.stateStoreOperationOutOfOrder(
"Cannot abort after committed")
case ABORTED => ABORTED
}
case COMMIT =>
state match {
case UPDATING => COMMITTED
case UPDATING =>
stateMachine.verifyStamp(stamp)
COMMITTED
case COMMITTED => throw StateStoreErrors.stateStoreOperationOutOfOrder(
"Cannot commit after committed")
case ABORTED => throw StateStoreErrors.stateStoreOperationOutOfOrder(
Expand All @@ -118,10 +127,14 @@ private[sql] class RocksDBStateStoreProvider
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit] {
_ =>
try {
abort()
if (state == UPDATING) {
abort()
}
} catch {
case NonFatal(e) =>
logWarning("Failed to abort state store", e)
} finally {
stateMachine.releaseStore(stamp, throwEx = false)
}
})

Expand Down Expand Up @@ -318,15 +331,18 @@ private[sql] class RocksDBStateStoreProvider
}

var checkpointInfo: Option[StateStoreCheckpointInfo] = None
private var storedMetrics: Option[RocksDBMetrics] = None

override def commit(): Long = synchronized {
validateState(List(UPDATING))

try {
verify(state == UPDATING, "Cannot commit after already committed or aborted")
val (newVersion, newCheckpointInfo) = rocksDB.commit()
checkpointInfo = Some(newCheckpointInfo)
storedMetrics = rocksDB.metricsOpt
validateAndTransitionState(COMMIT)
state = COMMITTED
stateMachine.releaseStore(stamp)

logInfo(log"Committed ${MDC(VERSION_NUM, newVersion)} " +
log"for ${MDC(STATE_STORE_ID, id)}")
newVersion
Expand All @@ -342,6 +358,7 @@ private[sql] class RocksDBStateStoreProvider
log"for ${MDC(STATE_STORE_ID, id)}")
rocksDB.rollback()
validateAndTransitionState(ABORT)
stateMachine.releaseStore(stamp)
}
}

Expand Down Expand Up @@ -541,15 +558,26 @@ private[sql] class RocksDBStateStoreProvider

override def stateStoreId: StateStoreId = stateStoreId_

private lazy val stateMachine: RocksDBStateStoreProviderStateMachine =
new RocksDBStateStoreProviderStateMachine(stateStoreId, RocksDBConf(storeConf))

override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = {
try {
if (version < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
}
rocksDB.load(
version,
stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None)
new RocksDBStateStore(version)
val stamp = stateMachine.acquireStore()
try {
rocksDB.load(
version,
stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None,
readOnly = false)
new RocksDBStateStore(version, stamp)
} catch {
case e: Throwable =>
stateMachine.releaseStore(stamp)
throw e
}
}
catch {
case e: SparkException
Expand All @@ -564,16 +592,58 @@ private[sql] class RocksDBStateStoreProvider
}
}

override def getReadStore(version: Long, uniqueId: Option[String] = None): StateStore = {
override def getWriteStore(
readStore: ReadStateStore,
version: Long,
uniqueId: Option[String] = None): StateStore = {
try {
if (version < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
}
rocksDB.load(
version,
stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None,
readOnly = true)
new RocksDBStateStore(version)
assert(version == readStore.version)
try {
rocksDB.load(
version,
stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None,
readOnly = false)
readStore match {
case stateStore: RocksDBStateStore =>
stateStore
case _ =>
throw new IllegalArgumentException
}
} catch {
case e: Throwable =>
stateMachine.releaseStore(readStore.getReadStamp)
throw e
}
} catch {
case e: SparkException
if Option(e.getCondition).exists(_.contains("CANNOT_LOAD_STATE_STORE")) =>
throw e
case e: OutOfMemoryError =>
throw QueryExecutionErrors.notEnoughMemoryToLoadStore(
stateStoreId.toString,
"ROCKSDB_STORE_PROVIDER",
e)
case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
}
}

override def getReadStore(version: Long, uniqueId: Option[String] = None): StateStore = {
try {
val stamp = stateMachine.acquireStore()
try {
rocksDB.load(
version,
stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None,
readOnly = true)
new RocksDBStateStore(version, stamp)
} catch {
case e: Throwable =>
stateMachine.releaseStore(stamp)
throw e
}
}
catch {
case e: SparkException
Expand All @@ -590,6 +660,7 @@ private[sql] class RocksDBStateStoreProvider

override def doMaintenance(): Unit = {
try {
stateMachine.maintenanceStore()
rocksDB.doMaintenance()
} catch {
// SPARK-46547 - Swallow non-fatal exception in maintenance task to avoid deadlock between
Expand All @@ -601,6 +672,7 @@ private[sql] class RocksDBStateStoreProvider
}

override def close(): Unit = {
stateMachine.closeStore()
rocksDB.close()
}

Expand Down Expand Up @@ -657,8 +729,15 @@ private[sql] class RocksDBStateStoreProvider
if (endVersion < snapshotVersion) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
}
rocksDB.loadFromSnapshot(snapshotVersion, endVersion)
new RocksDBStateStore(endVersion)
val stamp = stateMachine.acquireStore()
try {
rocksDB.loadFromSnapshot(snapshotVersion, endVersion)
new RocksDBStateStore(endVersion, stamp)
} catch {
case e: Throwable =>
stateMachine.releaseStore(stamp)
throw e
}
}
catch {
case e: OutOfMemoryError =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming.state

import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong
import javax.annotation.concurrent.GuardedBy

import org.apache.spark.internal.Logging
import org.apache.spark.sql.errors.QueryExecutionErrors

/**
 * Tracks whether a state store provider is currently released, acquired or closed, and
 * hands out a "stamp" to each successful acquirer.
 *
 * Every call to [[acquireStore]] produces a new stamp (a strictly increasing counter). The
 * holder must present that stamp to [[releaseStore]]; [[verifyStamp]] lets the holder check
 * that its stamp is still the current one. While no one holds the store the current stamp
 * is -1.
 *
 * All state changes happen while holding `instanceLock`; `currentValidStamp` may be read
 * without the lock.
 *
 * @param stateStoreId identifies the store in the error raised when acquisition times out
 * @param rocksDBConf  supplies `lockAcquireTimeoutMs`, the maximum time to wait for a
 *                     holder to release the store
 */
class RocksDBStateStoreProviderStateMachine(
    stateStoreId: StateStoreId,
    rocksDBConf: RocksDBConf) extends Logging {

  private sealed trait STATE
  private case object RELEASED extends STATE
  private case object ACQUIRED extends STATE
  private case object CLOSED extends STATE

  private sealed abstract class TRANSITION(name: String) {
    override def toString: String = name
  }
  private case object LOAD extends TRANSITION("load")
  private case object RELEASE extends TRANSITION("release")
  private case object CLOSE extends TRANSITION("close")
  private case object MAINTENANCE extends TRANSITION("maintenance")

  private val instanceLock = new Object()
  @GuardedBy("instanceLock")
  private var state: STATE = RELEASED
  // Set every time the state becomes ACQUIRED; used only for timeout diagnostics.
  @GuardedBy("instanceLock")
  private var acquiredThreadInfo: AcquiredThreadInfo = _

  // Can be read without holding any locks, but should only be updated when
  // instanceLock is held.
  // -1 indicates that the store is not locked.
  private[sql] val currentValidStamp = new AtomicLong(-1L)
  @GuardedBy("instanceLock")
  private var lastValidStamp: Long = 0L

  // Advances the stamp counter and publishes the new value as the current valid stamp.
  // Instance lock must be held.
  private def incAndGetStamp: Long = {
    lastValidStamp += 1
    currentValidStamp.set(lastValidStamp)
    lastValidStamp
  }

  /**
   * Blocks until the state is no longer ACQUIRED or `lockAcquireTimeoutMs` has elapsed.
   * If the store is still ACQUIRED after the wait, throws the error built by
   * `QueryExecutionErrors.unreleasedThreadError`, describing both the waiting and the
   * holding thread.
   *
   * Instance lock must be held (the method calls `instanceLock.wait`).
   *
   * @param transition the transition the caller is about to attempt; only used in the
   *                   error message
   */
  private def awaitNotLocked(transition: TRANSITION): Unit = {
    val waitStartTime = System.nanoTime()
    def timeWaitedMs: Long = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - waitStartTime)

    // wait() releases instanceLock while parked, so a holder can enter releaseStore.
    // The bounded wait keeps the timeout check live even without a notification.
    while (state == ACQUIRED && timeWaitedMs < rocksDBConf.lockAcquireTimeoutMs) {
      instanceLock.wait(10)
    }
    if (state == ACQUIRED) {
      val newAcquiredThreadInfo = AcquiredThreadInfo()
      // The holder is tracked through a weak reference, which may already be cleared.
      // Fall back to a placeholder so the timeout error itself is still reported instead
      // of a NoSuchElementException from an empty Option.
      val stackTraceOutput = acquiredThreadInfo.threadRef.get
        .map(_.getStackTrace.mkString("\n"))
        .getOrElse("<stack trace unavailable>")
      val loggingId = s"StateStoreId(opId=${stateStoreId.operatorId}," +
        s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})"
      throw QueryExecutionErrors.unreleasedThreadError(loggingId, transition.toString,
        newAcquiredThreadInfo.toString(), acquiredThreadInfo.toString(), timeWaitedMs,
        stackTraceOutput)
    }
  }

  /**
   * Returns oldState, newState.
   * Throws error if transition is illegal.
   * MUST be called for every StateStoreProvider method.
   * Caller MUST hold instance lock.
   */
  private def validateAndTransitionState(transition: TRANSITION): (STATE, STATE) = {
    val oldState = state
    val newState = transition match {
      case LOAD =>
        oldState match {
          case RELEASED => ACQUIRED
          case ACQUIRED => throw new IllegalStateException("Cannot lock when state is LOCKED")
          case CLOSED => throw new IllegalStateException("Cannot lock when state is CLOSED")
        }
      case RELEASE =>
        oldState match {
          case RELEASED => throw new IllegalStateException("Cannot unlock when state is UNLOCKED")
          case ACQUIRED => RELEASED
          case CLOSED => throw new IllegalStateException("Cannot unlock when state is CLOSED")
        }
      case CLOSE =>
        oldState match {
          case RELEASED => CLOSED
          case ACQUIRED => throw new IllegalStateException("Cannot closed when state is LOCKED")
          case CLOSED => CLOSED
        }
      case MAINTENANCE =>
        // Maintenance never changes the state; it is only rejected once closed.
        oldState match {
          case RELEASED => RELEASED
          case ACQUIRED => ACQUIRED
          case CLOSED =>
            throw new IllegalStateException("Cannot do maintenance when state is CLOSED")
        }
    }
    state = newState
    if (newState == ACQUIRED) {
      acquiredThreadInfo = AcquiredThreadInfo()
    }
    (oldState, newState)
  }

  /**
   * Checks that `stamp` is the currently valid stamp.
   *
   * @throws IllegalStateException if `stamp` differs from the current valid stamp
   */
  def verifyStamp(stamp: Long): Unit = {
    // Read the atomic once so the value reported is the value that was compared.
    val currentStamp = currentValidStamp.get()
    if (stamp != currentStamp) {
      throw new IllegalStateException(s"Invalid stamp $stamp, " +
        s"currentStamp: $currentStamp")
    }
  }

  /**
   * Releases the store if `stamp` is the currently valid stamp.
   *
   * @param stamp   the stamp returned by the matching [[acquireStore]] call
   * @param throwEx when true, a stale stamp raises an exception; when false it is
   *                reported through the return value instead
   * @return true if the store was released by this call, false if the stamp was stale
   *         and `throwEx` is false
   * @throws IllegalStateException if the stamp is stale and `throwEx` is true, or if the
   *                               RELEASE transition is illegal in the current state
   */
  def releaseStore(stamp: Long, throwEx: Boolean = true): Boolean = instanceLock.synchronized {
    if (currentValidStamp.compareAndSet(stamp, -1L)) {
      validateAndTransitionState(RELEASE)
      // Wake threads parked in awaitNotLocked so they do not have to wait for their
      // next poll tick to observe the release.
      instanceLock.notifyAll()
      true
    } else if (throwEx) {
      throw new IllegalStateException("Invalid stamp for release")
    } else {
      false
    }
  }

  /**
   * Waits for the store to be free, marks it ACQUIRED and returns a fresh stamp.
   *
   * @return the new stamp, to be passed to [[verifyStamp]] and [[releaseStore]]
   */
  def acquireStore(): Long = instanceLock.synchronized {
    awaitNotLocked(LOAD)
    validateAndTransitionState(LOAD)
    incAndGetStamp
  }

  /**
   * Validates that maintenance is allowed in the current state.
   *
   * @throws IllegalStateException if the state is CLOSED
   */
  def maintenanceStore(): Unit = instanceLock.synchronized {
    validateAndTransitionState(MAINTENANCE)
  }

  /** Waits for the store to be free and then moves the state to CLOSED. */
  def closeStore(): Unit = instanceLock.synchronized {
    awaitNotLocked(CLOSE)
    validateAndTransitionState(CLOSE)
  }
}
Loading
Loading