
[SPARK-48589][SQL][SS] Add option snapshotStartBatchId and snapshotPartitionId to state data source #46944


Closed

Changes from all commits (37 commits)
6db0e3d
initial implementation
eason-yuchen-liu Jun 4, 2024
7dad0c1
Merge branch 'skipSnapshotAtBatch' of https://github.com/eason-yuchen…
eason-yuchen-liu Jun 4, 2024
2475173
add test cases for two options in HDFS state store
eason-yuchen-liu Jun 6, 2024
07267b5
allow rocksdb to reconstruct state from a specific checkpoint
eason-yuchen-liu Jun 7, 2024
9d902d7
test directly on the method instead of end to end
eason-yuchen-liu Jun 10, 2024
eddb3c7
Merge branch 'apache:master' into skipSnapshotAtBatch
eason-yuchen-liu Jun 10, 2024
1a3d20a
make sure test is stable
eason-yuchen-liu Jun 10, 2024
292ec5d
delete useless test files
eason-yuchen-liu Jun 10, 2024
aa337c1
add new test on partition not found error
eason-yuchen-liu Jun 11, 2024
dfa712e
clean up and format
eason-yuchen-liu Jun 11, 2024
4ebd078
move partition error
eason-yuchen-liu Jun 11, 2024
1656580
improve doc
eason-yuchen-liu Jun 11, 2024
61dea35
minor
eason-yuchen-liu Jun 11, 2024
5229152
support reading join states
eason-yuchen-liu Jun 12, 2024
4825215
address reviews by Wei partially
eason-yuchen-liu Jun 13, 2024
20e1b9c
address comments from Anish & Wei
eason-yuchen-liu Jun 13, 2024
9eb6c76
Merge branch 'master' into skipSnapshotAtBatch
eason-yuchen-liu Jun 13, 2024
4d4cd70
log StateSourceOptions optionally
eason-yuchen-liu Jun 13, 2024
1870b35
Merge branch 'skipSnapshotAtBatch' of https://github.com/eason-yuchen…
eason-yuchen-liu Jun 13, 2024
fe9cea1
address more comments from Anish
eason-yuchen-liu Jun 14, 2024
3f266c1
style
eason-yuchen-liu Jun 17, 2024
2eb6646
also update the name of StateTable
eason-yuchen-liu Jun 21, 2024
be30817
Reflect more comments from Anish
eason-yuchen-liu Jun 22, 2024
3ece6f2
resort error-conditions
eason-yuchen-liu Jun 22, 2024
ef9b095
create integration test against golden files
eason-yuchen-liu Jun 25, 2024
876256e
reflect comments from Jungtaek
eason-yuchen-liu Jun 25, 2024
1a23abb
refactor the code to isolate from current state stores used by stream…
eason-yuchen-liu Jun 25, 2024
97ee3ef
some naming and formatting comments from Anish and Jungtaek
eason-yuchen-liu Jun 26, 2024
23639f4
create new error for SupportsFineGrainedReplayFromSnapshot
eason-yuchen-liu Jun 26, 2024
40b6dc6
move error to StateStoreErrors
eason-yuchen-liu Jun 26, 2024
e15213e
rename startVersion to snapshotVersion to make its function clear
eason-yuchen-liu Jun 27, 2024
42d952f
rename SupportsFineGrainedReplayFromSnapshot to SupportsFineGrainedRe…
eason-yuchen-liu Jun 27, 2024
6f1425d
reflect more comments from Jungtaek
eason-yuchen-liu Jun 27, 2024
4deb63e
throw the exception
eason-yuchen-liu Jun 27, 2024
d140708
provide the script to regenerate golden files
eason-yuchen-liu Jun 27, 2024
337785d
address comments from Anish
eason-yuchen-liu Jun 29, 2024
9dbe295
minor
eason-yuchen-liu Jul 2, 2024
17 changes: 17 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -224,6 +224,11 @@
"Error reading delta file <fileToRead> of <clazz>: <fileToRead> does not exist."
]
},
"CANNOT_READ_MISSING_SNAPSHOT_FILE" : {
"message" : [
"Error reading snapshot file <fileToRead> of <clazz>: <fileToRead> does not exist."
]
},
"CANNOT_READ_SNAPSHOT_FILE_KEY_SIZE" : {
"message" : [
"Error reading snapshot file <fileToRead> of <clazz>: key size cannot be <keySize>."
@@ -239,6 +244,11 @@
"Error reading streaming state file of <fileToRead> does not exist. If the stream job is restarted with a new or updated state operation, please create a new checkpoint location or clear the existing checkpoint location."
]
},
"SNAPSHOT_PARTITION_ID_NOT_FOUND" : {
"message" : [
"Partition id <snapshotPartitionId> not found for state of operator <operatorId> at <checkpointLocation>."
]
},
"UNCATEGORIZED" : {
"message" : [
""
@@ -3763,6 +3773,13 @@
],
"sqlState" : "42802"
},
"STATE_STORE_PROVIDER_DOES_NOT_SUPPORT_FINE_GRAINED_STATE_REPLAY" : {
"message" : [
"The given State Store Provider <inputClass> does not extend org.apache.spark.sql.execution.streaming.state.SupportsFineGrainedReplay.",
"Therefore, it does not support option snapshotStartBatchId in state data source."
],
"sqlState" : "42K06"
},
"STATE_STORE_UNSUPPORTED_OPERATION" : {
"message" : [
"<operationType> operation not supported with <entity>"
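The helper that raises the fine-grained replay condition is referenced later in this PR (StatePartitionReader calls StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay), but its diff is not shown in this view. Based on the message template above, it presumably looks like the following sketch; the return type and exact signature are assumptions:

    // Sketch only - modeled on existing StateStoreErrors helpers; the
    // parameter mirrors <inputClass> in the message template above.
    def stateStoreProviderDoesNotSupportFineGrainedReplay(
        inputClass: String): SparkUnsupportedOperationException = {
      new SparkUnsupportedOperationException(
        errorClass = "STATE_STORE_PROVIDER_DOES_NOT_SUPPORT_FINE_GRAINED_STATE_REPLAY",
        messageParameters = Map("inputClass" -> inputClass))
    }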
@@ -116,12 +116,16 @@ case class StateSourceOptions(
batchId: Long,
operatorId: Int,
storeName: String,
joinSide: JoinSideValues) {
joinSide: JoinSideValues,
snapshotStartBatchId: Option[Long],
snapshotPartitionId: Option[Int]) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE)

override def toString: String = {
s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " +
s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide)"
s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " +
s"snapshotStartBatchId=${snapshotStartBatchId.getOrElse("None")}, " +
s"snapshotPartitionId=${snapshotPartitionId.getOrElse("None")})"
}
}

@@ -131,6 +135,8 @@ object StateSourceOptions extends DataSourceOptions {
val OPERATOR_ID = newOption("operatorId")
val STORE_NAME = newOption("storeName")
val JOIN_SIDE = newOption("joinSide")
val SNAPSHOT_START_BATCH_ID = newOption("snapshotStartBatchId")
val SNAPSHOT_PARTITION_ID = newOption("snapshotPartitionId")

object JoinSideValues extends Enumeration {
type JoinSideValues = Value
@@ -190,7 +196,30 @@ object StateSourceOptions extends DataSourceOptions {
throw StateDataSourceErrors.conflictOptions(Seq(JOIN_SIDE, STORE_NAME))
}

StateSourceOptions(resolvedCpLocation, batchId, operatorId, storeName, joinSide)
val snapshotStartBatchId = Option(options.get(SNAPSHOT_START_BATCH_ID)).map(_.toLong)
if (snapshotStartBatchId.exists(_ < 0)) {
throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_START_BATCH_ID)
} else if (snapshotStartBatchId.exists(_ > batchId)) {
throw StateDataSourceErrors.invalidOptionValue(
SNAPSHOT_START_BATCH_ID, s"value should be less than or equal to $batchId")
}

val snapshotPartitionId = Option(options.get(SNAPSHOT_PARTITION_ID)).map(_.toInt)
if (snapshotPartitionId.exists(_ < 0)) {
throw StateDataSourceErrors.invalidOptionValueIsNegative(SNAPSHOT_PARTITION_ID)
}

// both snapshotPartitionId and snapshotStartBatchId are required at the same time, because
// each partition may have different checkpoint status
if (snapshotPartitionId.isDefined && snapshotStartBatchId.isEmpty) {
throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_START_BATCH_ID)
} else if (snapshotPartitionId.isEmpty && snapshotStartBatchId.isDefined) {
throw StateDataSourceErrors.requiredOptionUnspecified(SNAPSHOT_PARTITION_ID)
}

StateSourceOptions(
resolvedCpLocation, batchId, operatorId, storeName,
joinSide, snapshotStartBatchId, snapshotPartitionId)
}

private def resolvedCheckpointLocation(
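The validation above makes the two new options a mandatory pair, bounded by 0 <= snapshotStartBatchId <= batchId. A minimal usage sketch (paths and values are illustrative; "statestore" is the format name the state data source registers):

    // Read the operator's state as of batch 10, but rebuild only partition 3,
    // starting from the snapshot taken at batch 5 instead of the latest one.
    val stateDf = spark.read
      .format("statestore")
      .option("path", "/tmp/query/checkpoint")
      .option("batchId", 10)
      .option("snapshotStartBatchId", 5)
      .option("snapshotPartitionId", 3)
      .load()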
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, ReadStateStore, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

@@ -93,7 +93,19 @@ class StatePartitionReader(
}

private lazy val store: ReadStateStore = {
provider.getReadStore(partition.sourceOptions.batchId + 1)
partition.sourceOptions.snapshotStartBatchId match {
case None => provider.getReadStore(partition.sourceOptions.batchId + 1)

case Some(snapshotStartBatchId) =>
if (!provider.isInstanceOf[SupportsFineGrainedReplay]) {
throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay(
provider.getClass.toString)
}
provider.asInstanceOf[SupportsFineGrainedReplay]
.replayReadStateFromSnapshot(
snapshotStartBatchId + 1,
partition.sourceOptions.batchId + 1)
}
}

private lazy val iter: Iterator[InternalRow] = {
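The + 1 offsets above encode the store's versioning convention: committing batch N produces state store version N + 1. A small worked example of the translation (values illustrative):

    val snapshotStartBatchId = 5L // option: batch whose snapshot to start from
    val batchId = 10L             // option: batch whose state to read

    // Arguments passed to replayReadStateFromSnapshot above:
    val snapshotVersion = snapshotStartBatchId + 1 // 6: snapshot file to load
    val endVersion = batchId + 1                   // 11: reached by replaying deltas 7..11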
@@ -26,7 +26,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
import org.apache.spark.sql.execution.streaming.state.StateStoreConf
import org.apache.spark.sql.execution.streaming.state.{StateStoreConf, StateStoreErrors}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

@@ -81,9 +81,20 @@ class StateScan(
assert((tail - head + 1) == partitionNums.length,
s"No continuous partitions in state: ${partitionNums.mkString("Array(", ", ", ")")}")

partitionNums.map {
pn => new StateStoreInputPartition(pn, queryId, sourceOptions)
}.toArray
sourceOptions.snapshotPartitionId match {
case None => partitionNums.map { pn =>
new StateStoreInputPartition(pn, queryId, sourceOptions)
}.toArray

case Some(snapshotPartitionId) =>
if (partitionNums.contains(snapshotPartitionId)) {
Array(new StateStoreInputPartition(snapshotPartitionId, queryId, sourceOptions))
} else {
throw StateStoreErrors.stateStoreSnapshotPartitionNotFound(
snapshotPartitionId, sourceOptions.operatorId,
sourceOptions.stateCheckpointLocation.toString)
}
}
}
}

@@ -49,16 +49,21 @@ class StateTable(
}

override def name(): String = {
val desc = s"StateTable " +
var desc = s"StateTable " +
s"[stateCkptLocation=${sourceOptions.stateCheckpointLocation}]" +
s"[batchId=${sourceOptions.batchId}][operatorId=${sourceOptions.operatorId}]" +
s"[storeName=${sourceOptions.storeName}]"

if (sourceOptions.joinSide != JoinSideValues.none) {
desc + s"[joinSide=${sourceOptions.joinSide}]"
} else {
desc
desc += s"[joinSide=${sourceOptions.joinSide}]"
}
if (sourceOptions.snapshotStartBatchId.isDefined) {
desc += s"[snapshotStartBatchId=${sourceOptions.snapshotStartBatchId}]"
}
if (sourceOptions.snapshotPartitionId.isDefined) {
desc += s"[snapshotPartitionId=${sourceOptions.snapshotPartitionId}]"
}
desc
}

override def capabilities(): util.Set[TableCapability] = CAPABILITY
@@ -116,7 +116,8 @@ class StreamStreamJoinStatePartitionReader(
partitionId = partition.partition,
formatVersion,
skippedNullValueCount = None,
useStateStoreCoordinator = false
useStateStoreCoordinator = false,
snapshotStartVersion = partition.sourceOptions.snapshotStartBatchId.map(_ + 1)
)
}

@@ -71,7 +71,8 @@ import org.apache.spark.util.ArrayImplicits._
* to ensure re-executed RDD operations re-apply updates on the correct past version of the
* store.
*/
private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging {
private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging
with SupportsFineGrainedReplay {

private val providerName = "HDFSBackedStateStoreProvider"

@@ -683,6 +684,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
}
}

/**
* Try to read the snapshot file. If the snapshot file is not available, return [[None]].
*
* @param version the version of the snapshot file
*/
private def readSnapshotFile(version: Long): Option[HDFSBackedStateStoreMap] = {
val fileToRead = snapshotFile(version)
val map = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
@@ -883,4 +889,93 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
throw new IllegalStateException(msg)
}
}

/**
* Get the state store of endVersion by applying delta files on the snapshot of snapshotVersion.
* If snapshot for snapshotVersion does not exist, an error will be thrown.
*
* @param snapshotVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
* @return [[HDFSBackedStateStore]]
*/
override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore = {
val newMap = replayLoadedMapFromSnapshot(snapshotVersion, endVersion)
logInfo(log"Retrieved snapshot at version " +
log"${MDC(LogKeys.STATE_STORE_VERSION, snapshotVersion)} and apply delta files to version " +
log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for update")
new HDFSBackedStateStore(endVersion, newMap)
}

/**
* Get the state store of endVersion for reading by applying delta files on the snapshot of
* snapshotVersion. If snapshot for snapshotVersion does not exist, an error will be thrown.
*
* @param snapshotVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
* @return [[HDFSBackedReadStateStore]]
*/
override def replayReadStateFromSnapshot(snapshotVersion: Long, endVersion: Long):
ReadStateStore = {
val newMap = replayLoadedMapFromSnapshot(snapshotVersion, endVersion)
logInfo(log"Retrieved snapshot at version " +
log"${MDC(LogKeys.STATE_STORE_VERSION, snapshotVersion)} and apply delta files to version " +
log"${MDC(LogKeys.STATE_STORE_VERSION, endVersion)} of " +
log"${MDC(LogKeys.STATE_STORE_PROVIDER, HDFSBackedStateStoreProvider.this)} for read-only")
new HDFSBackedReadStateStore(endVersion, newMap)
}

/**
* Construct the state map at endVersion from snapshot of version snapshotVersion.
* Returns a new [[HDFSBackedStateStoreMap]]
* @param snapshotVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
*/
private def replayLoadedMapFromSnapshot(snapshotVersion: Long, endVersion: Long):
HDFSBackedStateStoreMap = synchronized {
try {
if (snapshotVersion < 1) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(snapshotVersion)
}
if (endVersion < snapshotVersion) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
}

val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
newMap.putAll(constructMapFromSnapshot(snapshotVersion, endVersion))

newMap
}
catch {
case e: Throwable => throw QueryExecutionErrors.cannotLoadStore(e)
}
}

private def constructMapFromSnapshot(snapshotVersion: Long, endVersion: Long):
HDFSBackedStateStoreMap = {
val (result, elapsedMs) = Utils.timeTakenMs {
val startVersionMap = synchronized { Option(loadedMaps.get(snapshotVersion)) } match {
case Some(value) => Option(value)
case None => readSnapshotFile(snapshotVersion)
}
if (startVersionMap.isEmpty) {
throw StateStoreErrors.stateStoreSnapshotFileNotFound(
snapshotFile(snapshotVersion).toString, toString())
}

// Load all the deltas from the version after the start version up to the end version.
val resultMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
resultMap.putAll(startVersionMap.get)
for (deltaVersion <- snapshotVersion + 1 to endVersion) {
updateFromDeltaFile(deltaVersion, resultMap)
}

resultMap
}

logDebug(s"Loading snapshot at version $snapshotVersion and apply delta files to version " +
s"$endVersion takes $elapsedMs ms.")

result
}
}
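These two replay methods are the provider's implementation of SupportsFineGrainedReplay. The trait's own definition is not part of this view; inferred from the overrides and the call sites above, its shape is presumably:

    // Inferred sketch - the actual trait (with scaladoc and any additional
    // members) lives in org.apache.spark.sql.execution.streaming.state.
    trait SupportsFineGrainedReplay {
      def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore
      def replayReadStateFromSnapshot(snapshotVersion: Long, endVersion: Long): ReadStateStore
    }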
@@ -226,6 +226,80 @@ class RocksDB(
this
}

/**
* Load from the start snapshot version and apply all the changelog records to reach the
* end version. Note that this will copy all the necessary files from DFS to local disk as needed,
* and possibly restart the native RocksDB instance.
*
* @param snapshotVersion version of the snapshot to start with
* @param endVersion end version
* @return A RocksDB instance loaded with the state endVersion replayed from snapshotVersion.
* Note that the instance will be read-only since this method is only used in State Data
* Source.
*/
def loadFromSnapshot(snapshotVersion: Long, endVersion: Long): RocksDB = {
assert(snapshotVersion >= 0 && endVersion >= snapshotVersion)
acquire(LoadStore)
Contributor: Lock release path is still the same, right? I assume we release on an abort?

Contributor Author: I think I am copying the existing implementation. Any changes needed here?

Contributor: Yeah, I'm guessing the unlock happens at the end as part of an abort within the state data source reader.
recordedMetrics = None
logInfo(
log"Loading snapshot at version ${MDC(LogKeys.VERSION_NUM, snapshotVersion)} and apply " +
log"changelog files to version ${MDC(LogKeys.VERSION_NUM, endVersion)}.")
try {
replayFromCheckpoint(snapshotVersion, endVersion)

logInfo(
Contributor: ditto

log"Loaded snapshot at version ${MDC(LogKeys.VERSION_NUM, snapshotVersion)} and apply " +
log"changelog files to version ${MDC(LogKeys.VERSION_NUM, endVersion)}.")
} catch {
case t: Throwable =>
loadedVersion = -1 // invalidate loaded data
throw t
}
this
}

/**
* Load from the start checkpoint version and apply all the changelog records to reach the
* end version.
* If the start version does not exist, it will throw an exception.
*
* @param snapshotVersion start checkpoint version
* @param endVersion end version
*/
private def replayFromCheckpoint(snapshotVersion: Long, endVersion: Long): Any = {
closeDB()
val metadata = fileManager.loadCheckpointFromDfs(snapshotVersion, workingDir)
loadedVersion = snapshotVersion

// reset last snapshot version
if (lastSnapshotVersion > snapshotVersion) {
// discard any newer snapshots
lastSnapshotVersion = 0L
latestSnapshot = None
}
openDB()

numKeysOnWritingVersion = if (!conf.trackTotalNumberOfRows) {
// we don't track the total number of rows - discard the number being tracked
-1L
} else if (metadata.numKeys < 0) {
// we track the total number of rows, but the snapshot doesn't have tracking number
// need to count keys now
countKeys()
} else {
metadata.numKeys
}
if (loadedVersion != endVersion) replayChangelog(endVersion)
Contributor: I'd like to see a user-friendly error message when the changelog file does not exist. Users may actually not be using changelog checkpointing and be misled into thinking it's supported. Surfacing a FileNotFoundException gives no hint about what is possibly incorrect - a smart user may notice what is wrong, but it's better to be user-friendly and to be part of the error class framework.

Contributor Author: The error users will get now when the changelog does not exist is:

Cause: org.apache.spark.SparkException: [CANNOT_LOAD_STATE_STORE.CANNOT_READ_STREAMING_STATE_FILE] An error occurred during loading state. Error reading streaming state file of <checkpointLocation> does not exist.
If the stream job is restarted with a new or updated state operation, please create a new checkpoint location or clear the existing checkpoint location. SQLSTATE: 58030

It does not have its own error class, so I think we should leave this as a follow-up task: give this error its own error class and catch it here to remind users of the possible cause.

// After changelog replay the numKeysOnWritingVersion will be updated to
// the correct number of keys in the loaded version.
numKeysOnLoadedVersion = numKeysOnWritingVersion
fileManagerMetrics = fileManager.latestLoadCheckpointMetrics

if (conf.resetStatsOnLoad) {
nativeStats.reset
}
}

/**
* Replay change log from the loaded version to the target version.
*/
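Per the commit "allow rocksdb to reconstruct state from a specific checkpoint", the RocksDB-backed provider wires loadFromSnapshot into the same SupportsFineGrainedReplay API. The provider-side diff is not shown in this view, so the following is only a hedged sketch; the rocksDB field and the RocksDBStateStore wrapper name are assumptions:

    // Sketch, not the PR's exact code: load the native db at the snapshot,
    // replay changelogs up to endVersion, then hand back a read-only store.
    override def replayReadStateFromSnapshot(
        snapshotVersion: Long,
        endVersion: Long): ReadStateStore = {
      rocksDB.loadFromSnapshot(snapshotVersion, endVersion) // method shown above
      new RocksDBStateStore(endVersion)                     // assumed wrapper class
    }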