
Commit 4999469

ericm-db authored and HeartSaVioR committed
[SPARK-48849][SS] Create OperatorStateMetadataV2 for the TransformWithStateExec operator
### What changes were proposed in this pull request?

Introduces the OperatorStateMetadataV2 format that integrates with the TransformWithStateExec operator. It keeps information about the TWS operator and will be used to enforce invariants between query runs. Each OperatorStateMetadataV2 holds a pointer to the StateSchemaV3 file for the corresponding operator. Purging will be introduced in this PR: #47286

### Why are the changes needed?

This is needed for State Metadata integration with the TransformWithState operator.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Added unit tests to StateStoreSuite and TransformWithStateSuite.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #47445 from ericm-db/metadata-v2.

Authored-by: Eric Marnadi <eric.marnadi@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
1 parent cf95e75 commit 4999469
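For orientation before the per-file diffs, here is a rough sketch of the shapes involved, inferred from the call sites in this commit; the actual definitions live in the state metadata sources and may differ in detail:

```scala
// Sketch only: shapes inferred from usages in this diff, not verbatim Spark definitions.
case class OperatorInfoV1(operatorId: Long, operatorName: String)

// V2 store metadata carries a pointer to the StateSchemaV3 file written for the operator.
case class StateStoreMetadataV2(
    storeName: String,
    numColsPrefixKey: Int,
    numPartitions: Int,
    stateSchemaFilePath: String)

// V2 operator metadata additionally carries a JSON blob of operator properties
// (for TransformWithStateExec: timeMode and outputMode), used to enforce
// invariants between query runs.
case class OperatorStateMetadataV2(
    operatorInfo: OperatorInfoV1,
    stateStoreInfo: Array[StateStoreMetadataV2],
    operatorPropertiesJson: String)
```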

File tree

10 files changed: +508 −102 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala

Lines changed: 49 additions & 16 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionRead
 import org.apache.spark.sql.execution.datasources.v2.state.StateDataSourceErrors
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.PATH
 import org.apache.spark.sql.execution.streaming.CheckpointFileManager
-import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1}
+import org.apache.spark.sql.execution.streaming.state.{OperatorInfoV1, OperatorStateMetadata, OperatorStateMetadataReader, OperatorStateMetadataV1, OperatorStateMetadataV2, StateStoreMetadataV1}
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -46,6 +46,7 @@ case class StateMetadataTableEntry(
     numPartitions: Int,
     minBatchId: Long,
     maxBatchId: Long,
+    operatorPropertiesJson: String,
     numColsPrefixKey: Int) {
   def toRow(): InternalRow = {
     new GenericInternalRow(
@@ -55,6 +56,7 @@
         numPartitions,
         minBatchId,
         maxBatchId,
+        UTF8String.fromString(operatorPropertiesJson),
         numColsPrefixKey))
   }
 }
@@ -68,6 +70,7 @@ object StateMetadataTableEntry {
       .add("numPartitions", IntegerType)
       .add("minBatchId", LongType)
       .add("maxBatchId", LongType)
+      .add("operatorProperties", StringType)
   }
 }

@@ -188,29 +191,59 @@ class StateMetadataPartitionReader(
     } else Array.empty
   }

-  private def allOperatorStateMetadata: Array[OperatorStateMetadata] = {
+  // Need this to be accessible from IncrementalExecution for the planning rule.
+  private[sql] def allOperatorStateMetadata: Array[OperatorStateMetadata] = {
     val stateDir = new Path(checkpointLocation, "state")
     val opIds = fileManager
       .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => pathToLong(f.getPath)).sorted
     opIds.map { opId =>
-      new OperatorStateMetadataReader(new Path(stateDir, opId.toString), hadoopConf).read()
+      val operatorIdPath = new Path(stateDir, opId.toString)
+      // check if OperatorStateMetadataV2 path exists, if it does, read it
+      // otherwise, fall back to OperatorStateMetadataV1
+      val operatorStateMetadataV2Path = OperatorStateMetadataV2.metadataDirPath(operatorIdPath)
+      val operatorStateMetadataVersion = if (fileManager.exists(operatorStateMetadataV2Path)) {
+        2
+      } else {
+        1
+      }
+      OperatorStateMetadataReader.createReader(
+        operatorIdPath, hadoopConf, operatorStateMetadataVersion).read() match {
+        case Some(metadata) => metadata
+        case None => OperatorStateMetadataV1(OperatorInfoV1(opId, null),
+          Array(StateStoreMetadataV1(null, -1, -1)))
+      }
     }
   }

   private[sql] lazy val stateMetadata: Iterator[StateMetadataTableEntry] = {
     allOperatorStateMetadata.flatMap { operatorStateMetadata =>
-      require(operatorStateMetadata.version == 1)
-      val operatorStateMetadataV1 = operatorStateMetadata.asInstanceOf[OperatorStateMetadataV1]
-      operatorStateMetadataV1.stateStoreInfo.map { stateStoreMetadata =>
-        StateMetadataTableEntry(operatorStateMetadataV1.operatorInfo.operatorId,
-          operatorStateMetadataV1.operatorInfo.operatorName,
-          stateStoreMetadata.storeName,
-          stateStoreMetadata.numPartitions,
-          if (batchIds.nonEmpty) batchIds.head else -1,
-          if (batchIds.nonEmpty) batchIds.last else -1,
-          stateStoreMetadata.numColsPrefixKey
-        )
-      }
+      require(operatorStateMetadata.version == 1 || operatorStateMetadata.version == 2)
+      operatorStateMetadata match {
+        case v1: OperatorStateMetadataV1 =>
+          v1.stateStoreInfo.map { stateStoreMetadata =>
+            StateMetadataTableEntry(v1.operatorInfo.operatorId,
+              v1.operatorInfo.operatorName,
+              stateStoreMetadata.storeName,
+              stateStoreMetadata.numPartitions,
+              if (batchIds.nonEmpty) batchIds.head else -1,
+              if (batchIds.nonEmpty) batchIds.last else -1,
+              null,
+              stateStoreMetadata.numColsPrefixKey
+            )
+          }
+        case v2: OperatorStateMetadataV2 =>
+          v2.stateStoreInfo.map { stateStoreMetadata =>
+            StateMetadataTableEntry(v2.operatorInfo.operatorId,
+              v2.operatorInfo.operatorName,
+              stateStoreMetadata.storeName,
+              stateStoreMetadata.numPartitions,
+              if (batchIds.nonEmpty) batchIds.head else -1,
+              if (batchIds.nonEmpty) batchIds.last else -1,
+              v2.operatorPropertiesJson,
+              -1 // numColsPrefixKey is not available in OperatorStateMetadataV2
+            )
+          }
+      }
     }
-    }
-  }.iterator
+  }.iterator
 }
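A minimal usage sketch of the metadata source after this change, assuming the `state-metadata` short name registered by `StateMetadataSource` and a checkpoint written by a transformWithState query; the checkpoint path and session setup are assumptions, and the selected columns follow the schema above:

```scala
// Hypothetical spark-shell snippet; the checkpoint path is illustrative.
val metadataDf = spark.read
  .format("state-metadata")
  .load("/tmp/checkpoints/tws-query")

// V2 operators populate operatorProperties (JSON) and report numColsPrefixKey as -1;
// V1 operators leave operatorProperties null, as in the reader above.
metadataDf
  .select("operatorId", "operatorName", "operatorProperties", "numColsPrefixKey")
  .show(truncate = false)
```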

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala

Lines changed: 15 additions & 8 deletions
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadat
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
 import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
 import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
-import org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataWriter
+import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadataV1, OperatorStateMetadataV2, OperatorStateMetadataWriter}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -208,13 +208,16 @@
       }
       val schemaValidationResult = statefulOp.
         validateAndMaybeEvolveStateSchema(hadoopConf, currentBatchId, stateSchemaVersion)
+      val stateSchemaPaths = schemaValidationResult.map(_.schemaPath)
       // write out the state schema paths to the metadata file
       statefulOp match {
-        case stateStoreWriter: StateStoreWriter =>
-          val metadata = stateStoreWriter.operatorStateMetadata()
-          // TODO: [SPARK-48849] Populate metadata with stateSchemaPaths if metadata version is v2
-          val metadataWriter = new OperatorStateMetadataWriter(new Path(
-            checkpointLocation, stateStoreWriter.getStateInfo.operatorId.toString), hadoopConf)
+        case ssw: StateStoreWriter =>
+          val metadata = ssw.operatorStateMetadata(stateSchemaPaths)
+          val metadataWriter = OperatorStateMetadataWriter.createWriter(
+            new Path(checkpointLocation, ssw.getStateInfo.operatorId.toString),
+            hadoopConf,
+            ssw.operatorStateMetadataVersion,
+            Some(currentBatchId))
           metadataWriter.write(metadata)
         case _ =>
       }
@@ -456,8 +459,12 @@
         val reader = new StateMetadataPartitionReader(
           new Path(checkpointLocation).getParent.toString,
           new SerializableConfiguration(hadoopConf))
-        ret = reader.stateMetadata.map { metadataTableEntry =>
-          metadataTableEntry.operatorId -> metadataTableEntry.operatorName
+        val opMetadataList = reader.allOperatorStateMetadata
+        ret = opMetadataList.map {
+          case OperatorStateMetadataV1(operatorInfo, _) =>
+            operatorInfo.operatorId -> operatorInfo.operatorName
+          case OperatorStateMetadataV2(operatorInfo, _, _) =>
+            operatorInfo.operatorId -> operatorInfo.operatorName
         }.toMap
       } catch {
         case e: Exception =>
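The factory used at the call site above is not part of this diff; a minimal sketch of the dispatch it implies, assuming `createWriter` simply switches on the operator's metadata version. The writer class names and constructor signatures here are assumptions for illustration, not the exact Spark implementation:

```scala
// Assumed shape of the factory behind OperatorStateMetadataWriter.createWriter.
object OperatorStateMetadataWriter {
  def createWriter(
      stateCheckpointPath: Path,
      hadoopConf: Configuration,
      version: Int,
      currentBatchId: Option[Long] = None): OperatorStateMetadataWriter = version match {
    case 1 => new OperatorStateMetadataV1Writer(stateCheckpointPath, hadoopConf)
    case 2 =>
      // V2 metadata is written per batch, so the writer also needs the current batch id.
      new OperatorStateMetadataV2Writer(stateCheckpointPath, hadoopConf, currentBatchId.get)
    case v => throw new IllegalArgumentException(s"Unsupported operator metadata version: $v")
  }
}
```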

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala

Lines changed: 4 additions & 2 deletions
@@ -227,10 +227,12 @@ case class StreamingSymmetricHashJoinExec(
   private val stateStoreNames =
     SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)

-  override def operatorStateMetadata(): OperatorStateMetadata = {
+  override def operatorStateMetadata(
+      stateSchemaPaths: List[String] = List.empty): OperatorStateMetadata = {
     val info = getStateInfo
     val operatorInfo = OperatorInfoV1(info.operatorId, shortName)
-    val stateStoreInfo = stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray
+    val stateStoreInfo =
+      stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray
     OperatorStateMetadataV1(operatorInfo, stateStoreInfo)
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala

Lines changed: 46 additions & 5 deletions
@@ -21,6 +21,10 @@ import java.util.concurrent.TimeUnit.NANOSECONDS

 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
+import org.json4s.JsonAST.JValue
+import org.json4s.JsonDSL._
+import org.json4s.JString
+import org.json4s.jackson.JsonMethods.{compact, render}

 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
@@ -96,6 +100,8 @@ case class TransformWithStateExec(
     }
   }

+  override def operatorStateMetadataVersion: Int = 2
+
   /**
    * We initialize this processor handle in the driver to run the init function
    * and fetch the schemas of the state variables initialized in this processor.
@@ -382,12 +388,47 @@
       batchId: Long,
       stateSchemaVersion: Int): List[StateSchemaValidationResult] = {
     assert(stateSchemaVersion >= 3)
-    val newColumnFamilySchemas = getColFamilySchemas()
+    val newSchemas = getColFamilySchemas()
     val stateSchemaDir = stateSchemaDirPath()
-    val stateSchemaFilePath = new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")
-    List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
-      newColumnFamilySchemas.values.toList, session.sessionState, stateSchemaVersion,
-      schemaFilePath = Some(stateSchemaFilePath)))
+    val newStateSchemaFilePath =
+      new Path(stateSchemaDir, s"${batchId}_${UUID.randomUUID().toString}")
+    val metadataPath = new Path(getStateInfo.checkpointLocation, s"${getStateInfo.operatorId}")
+    val metadataReader = OperatorStateMetadataReader.createReader(
+      metadataPath, hadoopConf, operatorStateMetadataVersion)
+    val operatorStateMetadata = metadataReader.read()
+    val oldStateSchemaFilePath: Option[Path] = operatorStateMetadata match {
+      case Some(metadata) =>
+        metadata match {
+          case v2: OperatorStateMetadataV2 =>
+            Some(new Path(v2.stateStoreInfo.head.stateSchemaFilePath))
+          case _ => None
+        }
+      case None => None
+    }
+    List(StateSchemaCompatibilityChecker.
+      validateAndMaybeEvolveStateSchema(getStateInfo, hadoopConf,
+        newSchemas.values.toList, session.sessionState, stateSchemaVersion,
+        storeName = StateStoreId.DEFAULT_STORE_NAME,
+        oldSchemaFilePath = oldStateSchemaFilePath,
+        newSchemaFilePath = Some(newStateSchemaFilePath)))
+  }
+
+  /** Metadata of this stateful operator and its states stores. */
+  override def operatorStateMetadata(
+      stateSchemaPaths: List[String]): OperatorStateMetadata = {
+    val info = getStateInfo
+    val operatorInfo = OperatorInfoV1(info.operatorId, shortName)
+    // stateSchemaFilePath should be populated at this point
+    val stateStoreInfo =
+      Array(StateStoreMetadataV2(
+        StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions, stateSchemaPaths.head))
+
+    val operatorPropertiesJson: JValue =
+      ("timeMode" -> JString(timeMode.toString)) ~
+      ("outputMode" -> JString(outputMode.toString))
+
+    val json = compact(render(operatorPropertiesJson))
+    OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json)
   }

   private def stateSchemaDirPath(): Path = {
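The operatorProperties blob above is built with json4s' DSL. A standalone sketch of writing and reading that JSON; the read-back/validation side is an assumption here, since this commit only writes the properties into OperatorStateMetadataV2:

```scala
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, parse, render}

// Build the properties the same way operatorStateMetadata() does above.
val props: JValue =
  ("timeMode" -> JString("ProcessingTime")) ~
  ("outputMode" -> JString("Update"))
val json = compact(render(props)) // {"timeMode":"ProcessingTime","outputMode":"Update"}

// Hypothetical consumer: a later run could parse the stored JSON and compare it
// against the current query's settings before reusing the state.
implicit val formats: Formats = DefaultFormats
val restored = parse(json)
assert((restored \ "timeMode").extract[String] == "ProcessingTime")
assert((restored \ "outputMode").extract[String] == "Update")
```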
