Skip to content

Commit f78dcb9

Browse files
jingz-dbsweisdb
authored andcommitted
[SPARK-47363][SS] Initial State without state reader implementation for State API v2
### What changes were proposed in this pull request? This PR adds support for users to provide a Dataframe that can be used to instantiate state for the query in the first batch for arbitrary state API v2. Note that populating the initial state will only happen for the first batch of the new streaming query. Trying to re-initialize state for the same grouping key will result in an error. ### Why are the changes needed? These changes are needed to support initial state. The changes are part of the work around adding new stateful streaming operator for arbitrary state mgmt that provides a bunch of new features listed in the SPIP JIRA here - https://issues.apache.org/jira/browse/SPARK-45939 ### Does this PR introduce _any_ user-facing change? Yes. This PR introduces a new function: ``` def transformWithState( statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeoutMode: TimeoutMode, outputMode: OutputMode, initialState: KeyValueGroupedDataset[K, S]): Dataset[U] ``` ### How was this patch tested? Unit tests in `TransformWithStateWithInitialStateSuite` ### Was this patch authored or co-authored using generative AI tooling? No Closes apache#45467 from jingz-db/initial-state-state-v2. Lead-authored-by: jingz-db <jing.zhan@databricks.com> Co-authored-by: Jing Zhan <135738831+jingz-db@users.noreply.github.com> Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
1 parent 0177327 commit f78dcb9

File tree

12 files changed

+661
-69
lines changed

12 files changed

+661
-69
lines changed

common/utils/src/main/resources/error/error-classes.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3553,6 +3553,12 @@
35533553
],
35543554
"sqlState" : "42802"
35553555
},
3556+
"STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY" : {
3557+
"message" : [
3558+
"Cannot re-initialize state on the same grouping key during initial state handling for stateful processor. Invalid grouping key=<groupingKey>."
3559+
],
3560+
"sqlState" : "42802"
3561+
},
35563562
"STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : {
35573563
"message" : [
35583564
"Failed to create column family with unsupported starting character and name=<colFamilyName>."

docs/sql-error-conditions.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2162,6 +2162,12 @@ Failed to perform stateful processor operation=`<operationType>` with invalid ha
21622162

21632163
Failed to perform stateful processor operation=`<operationType>` with invalid timeoutMode=`<timeoutMode>`
21642164

2165+
### STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY
2166+
2167+
[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
2168+
2169+
Cannot re-initialize state on the same grouping key during initial state handling for stateful processor. Invalid grouping key=`<groupingKey>`.
2170+
21652171
### STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS
21662172

21672173
[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)

sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,22 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
9191
statefulProcessorHandle
9292
}
9393
}
94+
95+
/**
96+
* Stateful processor with support for specifying initial state.
97+
* Accepts a user-defined type as initial state to be initialized in the first batch.
98+
* This can be used for starting a new streaming query with existing state from a
99+
* previous streaming query.
100+
*/
101+
@Experimental
102+
@Evolving
103+
trait StatefulProcessorWithInitialState[K, I, O, S] extends StatefulProcessor[K, I, O] {
104+
105+
/**
106+
* Function that will be invoked only in the first batch for users to process initial states.
107+
*
108+
* @param key - grouping key
109+
* @param initialState - A row in the initial state to be processed
110+
*/
111+
def handleInitialState(key: K, initialState: S): Unit
112+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,46 @@ object TransformWithState {
588588
outputMode,
589589
keyEncoder.asInstanceOf[ExpressionEncoder[Any]],
590590
CatalystSerde.generateObjAttr[U],
591-
child
591+
child,
592+
hasInitialState = false,
593+
// the following parameters will not be used in physical plan if hasInitialState = false
594+
initialStateGroupingAttrs = groupingAttributes,
595+
initialStateDataAttrs = dataAttributes,
596+
initialStateDeserializer =
597+
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
598+
initialState = LocalRelation(encoderFor[K].schema) // empty data set
599+
)
600+
CatalystSerde.serialize[U](mapped)
601+
}
602+
603+
// This apply() is to invoke TransformWithState object with hasInitialState set to true
604+
def apply[K: Encoder, V: Encoder, U: Encoder, S: Encoder](
605+
groupingAttributes: Seq[Attribute],
606+
dataAttributes: Seq[Attribute],
607+
statefulProcessor: StatefulProcessor[K, V, U],
608+
timeoutMode: TimeoutMode,
609+
outputMode: OutputMode,
610+
child: LogicalPlan,
611+
initialStateGroupingAttrs: Seq[Attribute],
612+
initialStateDataAttrs: Seq[Attribute],
613+
initialState: LogicalPlan): LogicalPlan = {
614+
val keyEncoder = encoderFor[K]
615+
val mapped = new TransformWithState(
616+
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
617+
UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
618+
groupingAttributes,
619+
dataAttributes,
620+
statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]],
621+
timeoutMode,
622+
outputMode,
623+
keyEncoder.asInstanceOf[ExpressionEncoder[Any]],
624+
CatalystSerde.generateObjAttr[U],
625+
child,
626+
hasInitialState = true,
627+
initialStateGroupingAttrs,
628+
initialStateDataAttrs,
629+
UnresolvedDeserializer(encoderFor[S].deserializer, initialStateDataAttrs),
630+
initialState
592631
)
593632
CatalystSerde.serialize[U](mapped)
594633
}
@@ -604,10 +643,18 @@ case class TransformWithState(
604643
outputMode: OutputMode,
605644
keyEncoder: ExpressionEncoder[Any],
606645
outputObjAttr: Attribute,
607-
child: LogicalPlan) extends UnaryNode with ObjectProducer {
646+
child: LogicalPlan,
647+
hasInitialState: Boolean = false,
648+
initialStateGroupingAttrs: Seq[Attribute],
649+
initialStateDataAttrs: Seq[Attribute],
650+
initialStateDeserializer: Expression,
651+
initialState: LogicalPlan) extends BinaryNode with ObjectProducer {
608652

609-
override protected def withNewChildInternal(newChild: LogicalPlan): TransformWithState =
610-
copy(child = newChild)
653+
override def left: LogicalPlan = child
654+
override def right: LogicalPlan = initialState
655+
override protected def withNewChildrenInternal(
656+
newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithState =
657+
copy(child = newLeft, initialState = newRight)
611658
}
612659

613660
/** Factory for constructing new `FlatMapGroupsInR` nodes. */

sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
2626
import org.apache.spark.sql.execution.QueryExecution
2727
import org.apache.spark.sql.expressions.ReduceAggregator
2828
import org.apache.spark.sql.internal.TypedAggUtils
29-
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode}
29+
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode}
3030

3131
/**
3232
* A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
@@ -676,6 +676,42 @@ class KeyValueGroupedDataset[K, V] private[sql](
676676
)
677677
}
678678

679+
/**
680+
* (Scala-specific)
681+
* Invokes methods defined in the stateful processor used in arbitrary state API v2.
682+
* Functions as the function above, but with additional initial state.
683+
*
684+
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
685+
* @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
686+
* @param statefulProcessor Instance of statefulProcessor whose functions will
687+
* be invoked by the operator.
688+
* @param timeoutMode The timeout mode of the stateful processor.
689+
* @param outputMode The output mode of the stateful processor. Defaults to APPEND mode.
690+
* @param initialState User provided initial state that will be used to initiate state for
691+
* the query in the first batch.
692+
*
693+
*/
694+
private[sql] def transformWithState[U: Encoder, S: Encoder](
695+
statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
696+
timeoutMode: TimeoutMode,
697+
outputMode: OutputMode,
698+
initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
699+
Dataset[U](
700+
sparkSession,
701+
TransformWithState[K, V, U, S](
702+
groupingAttributes,
703+
dataAttributes,
704+
statefulProcessor,
705+
timeoutMode,
706+
outputMode,
707+
child = logicalPlan,
708+
initialState.groupingAttributes,
709+
initialState.dataAttributes,
710+
initialState.queryExecution.analyzed
711+
)
712+
)
713+
}
714+
679715
/**
680716
* (Scala-specific)
681717
* Reduces the elements of each group of data using the specified binary function.

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
752752
case TransformWithState(
753753
keyDeserializer, valueDeserializer, groupingAttributes,
754754
dataAttributes, statefulProcessor, timeoutMode, outputMode,
755-
keyEncoder, outputAttr, child) =>
755+
keyEncoder, outputAttr, child, hasInitialState,
756+
initialStateGroupingAttrs, initialStateDataAttrs,
757+
initialStateDeserializer, initialState) =>
756758
val execPlan = TransformWithStateExec(
757759
keyDeserializer,
758760
valueDeserializer,
@@ -767,7 +769,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
767769
batchTimestampMs = None,
768770
eventTimeWatermarkForLateEvents = None,
769771
eventTimeWatermarkForEviction = None,
770-
planLater(child))
772+
planLater(child),
773+
isStreaming = true,
774+
hasInitialState,
775+
initialStateGroupingAttrs,
776+
initialStateDataAttrs,
777+
initialStateDeserializer,
778+
planLater(initialState))
771779
execPlan :: Nil
772780
case _ =>
773781
Nil
@@ -918,10 +926,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
918926
) :: Nil
919927
case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes,
920928
dataAttributes, statefulProcessor, timeoutMode, outputMode, keyEncoder,
921-
outputObjAttr, child) =>
929+
outputObjAttr, child, hasInitialState,
930+
initialStateGroupingAttrs, initialStateDataAttrs,
931+
initialStateDeserializer, initialState) =>
922932
TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer, valueDeserializer,
923933
groupingAttributes, dataAttributes, statefulProcessor, timeoutMode, outputMode,
924-
keyEncoder, outputObjAttr, planLater(child)) :: Nil
934+
keyEncoder, outputObjAttr, planLater(child), hasInitialState,
935+
initialStateGroupingAttrs, initialStateDataAttrs,
936+
initialStateDeserializer, planLater(initialState)) :: Nil
925937

926938
case _: FlatMapGroupsInPandasWithState =>
927939
// TODO(SPARK-40443): support applyInPandasWithState in batch query

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,11 +268,13 @@ class IncrementalExecution(
268268
)
269269

270270
case t: TransformWithStateExec =>
271+
val hasInitialState = (currentBatchId == 0L && t.hasInitialState)
271272
t.copy(
272273
stateInfo = Some(nextStatefulOperationStateInfo()),
273274
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
274275
eventTimeWatermarkForLateEvents = None,
275-
eventTimeWatermarkForEviction = None
276+
eventTimeWatermarkForEviction = None,
277+
hasInitialState = hasInitialState
276278
)
277279

278280
case m: FlatMapGroupsInPandasWithStateExec =>

0 commit comments

Comments
 (0)