[SPARK-50967][SS] Add option to skip emitting initial state keys within the FMGWS operator

### What changes were proposed in this pull request?
Add option to skip emitting initial state keys within the FMGWS operator

### Why are the changes needed?
Without this change, users do not have an easy way to filter out the initial state rows emitted as part of the first batch.
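
For context, here is a minimal sketch of the behavior this targets, mirroring the new tests at the bottom of the diff (`spark` and the streaming input `inputStream` are assumed for illustration, not part of this PR):

```scala
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
import spark.implicits._

// Keys "orange" and "mango" exist only in the initial state, not in the input.
val initialState = Seq(("apple", 1L), ("orange", 2L), ("mango", 5L))
  .toDS().groupByKey(_._1).mapValues(_._2)

val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => {
  val count = state.getOption.getOrElse(0L) + values.size
  state.update(count)
  Iterator.single((key, count))
}

// inputStream: an assumed streaming Dataset[String] of fruit names.
val counts = inputStream
  .groupByKey(identity)
  .flatMapGroupsWithState(
    OutputMode.Update, GroupStateTimeout.NoTimeout, initialState)(fruitCountFunc)

// By default, the first batch also emits ("orange", 2) and ("mango", 5), sourced
// purely from the initial state. With the new option enabled, keys that have
// initial state but no input data are not emitted until data arrives for them.
```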

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added unit tests

```
[info] Run completed in 15 seconds, 349 milliseconds.
[info] Total number of tests run: 4
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 4, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
[success] Total time: 51 s, completed Jan 23, 2025, 4:26:03 PM
```

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#49632 from anishshri-db/task/SPARK-50967-fix.

Authored-by: Anish Shrigondekar <anish.shrigondekar@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
anishshri-db authored and HeartSaVioR committed Feb 1, 2025
1 parent bf442c5 commit cfc3f1f
Showing 6 changed files with 197 additions and 33 deletions.
@@ -2292,6 +2292,15 @@ object SQLConf {
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
.createWithDefault(2)

val FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS =
buildConf("spark.sql.streaming.flatMapGroupsWithState.skipEmittingInitialStateKeys")
.internal()
.doc("When true, the flatMapGroupsWithState operation in a streaming query will not emit " +
"results for the initial state keys of each group.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation")
.doc("The default location for storing checkpoint data for streaming queries.")
.version("2.0.0")
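As a usage note (not part of the diff): the conf is registered as internal, so it must be flipped explicitly, e.g.:

```scala
// Sketch: enabling the new internal conf on an assumed active SparkSession;
// the key string matches the buildConf entry above.
spark.conf.set(
  "spark.sql.streaming.flatMapGroupsWithState.skipEmittingInitialStateKeys",
  "true")
```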
@@ -737,11 +737,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
timeout, hasInitialState, stateGroupAttr, sda, sDeser, initialState, child) =>
val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
val skipEmittingInitialStateKeys =
conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS)
val execPlan = FlatMapGroupsWithStateExec(
func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, dataAttr, sda, outputAttr,
None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs = None,
eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None,
planLater(initialState), hasInitialState, planLater(child)
planLater(initialState), hasInitialState, skipEmittingInitialStateKeys, planLater(child)
)
execPlan :: Nil
case _ =>
@@ -829,7 +831,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val execPlan = python.FlatMapGroupsInPandasWithStateExec(
func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout,
batchTimestampMs = None, eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None, planLater(child)
eventTimeWatermarkForEviction = None,
skipEmittingInitialStateKeys = false,
planLater(child)
)
execPlan :: Nil
case _ =>
@@ -954,10 +958,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
f, keyDeserializer, valueDeserializer, grouping, data, output, stateEncoder, outputMode,
isFlatMapGroupsWithState, timeout, hasInitialState, initialStateGroupAttrs,
initialStateDataAttrs, initialStateDeserializer, initialState, child) =>
val skipEmittingInitialStateKeys =
conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS)
FlatMapGroupsWithStateExec.generateSparkPlanForBatchQueries(
f, keyDeserializer, valueDeserializer, initialStateDeserializer, grouping,
initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
hasInitialState, planLater(initialState), planLater(child)
hasInitialState, skipEmittingInitialStateKeys, planLater(initialState), planLater(child)
) :: Nil
case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes,
dataAttributes, statefulProcessor, timeMode, outputMode, keyEncoder,
@@ -50,6 +50,7 @@ import org.apache.spark.util.CompletionIterator
* @param batchTimestampMs processing timestamp of the current batch.
* @param eventTimeWatermarkForLateEvents event time watermark for filtering late events
* @param eventTimeWatermarkForEviction event time watermark for state eviction
* @param skipEmittingInitialStateKeys whether to skip emitting initial state df keys
* @param child logical plan of the underlying data
*/
case class FlatMapGroupsInPandasWithStateExec(
@@ -64,6 +65,7 @@ case class FlatMapGroupsInPandasWithStateExec(
batchTimestampMs: Option[Long],
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
skipEmittingInitialStateKeys: Boolean,
child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {

// TODO(SPARK-40444): Add the support of initial state.
@@ -137,7 +139,8 @@ case class FlatMapGroupsInPandasWithStateExec(

override def processNewDataWithInitialState(
childDataIter: Iterator[InternalRow],
initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = {
initStateIter: Iterator[InternalRow],
skipEmittingInitialStateKeys: Boolean): Iterator[InternalRow] = {
throw SparkUnsupportedOperationException()
}

@@ -52,6 +52,7 @@ trait FlatMapGroupsWithStateExecBase
protected val initialStateDataAttrs: Seq[Attribute]
protected val initialState: SparkPlan
protected val hasInitialState: Boolean
protected val skipEmittingInitialStateKeys: Boolean

val stateInfo: Option[StatefulOperatorStateInfo]
protected val stateEncoder: ExpressionEncoder[Any]
@@ -145,7 +146,8 @@ trait FlatMapGroupsWithStateExecBase

val processedOutputIterator = initialStateIterOption match {
case Some(initStateIter) if initStateIter.hasNext =>
processor.processNewDataWithInitialState(filteredIter, initStateIter)
processor.processNewDataWithInitialState(filteredIter, initStateIter,
skipEmittingInitialStateKeys)
case _ => processor.processNewData(filteredIter)
}

@@ -301,7 +303,8 @@ trait FlatMapGroupsWithStateExecBase
*/
def processNewDataWithInitialState(
childDataIter: Iterator[InternalRow],
initStateIter: Iterator[InternalRow]
initStateIter: Iterator[InternalRow],
skipEmittingInitialStateKeys: Boolean
): Iterator[InternalRow] = {

if (!childDataIter.hasNext && !initStateIter.hasNext) return Iterator.empty
@@ -312,7 +315,8 @@ trait FlatMapGroupsWithStateExecBase
val groupedInitialStateIter =
GroupedIterator(initStateIter, initialStateGroupAttrs, initialState.output)

// Create a CoGroupedIterator that will group the two iterators together for every key group.
// Create a CoGroupedIterator that will group the two iterators together for every
// key group.
new CoGroupedIterator(
groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap {
case (keyRow, valueRowIter, initialStateRowIter) =>
@@ -326,12 +330,17 @@ trait FlatMapGroupsWithStateExecBase
val initStateObj = getStateObj.get(initialStateRow)
stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP)
}
// We apply the values for the key after applying the initial state.
callFunctionAndUpdateState(
stateManager.getState(store, keyUnsafeRow),

if (skipEmittingInitialStateKeys && valueRowIter.isEmpty) {
// If the user has specified to skip emitting the keys that only have initial state
// and no data, then we should not call the function for such keys.
Iterator.empty
} else {
callFunctionAndUpdateState(
stateManager.getState(store, keyUnsafeRow),
valueRowIter,
hasTimedOut = false
)
hasTimedOut = false)
}
}
}
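
The per-key decision added above reduces to the following shape, shown here as a standalone model over the cogrouped value iterator (an illustration, not the operator code itself):

```scala
// Model: state for the key has already been seeded from the initial state df;
// the user function runs only if the key has data or skipping is disabled.
def emitForKey[O](
    valueRowIter: Iterator[Any],
    skipEmittingInitialStateKeys: Boolean)(
    callFunctionAndUpdateState: Iterator[Any] => Iterator[O]): Iterator[O] = {
  if (skipEmittingInitialStateKeys && valueRowIter.isEmpty) {
    // Key came only from the initial state: state is stored, nothing is emitted.
    Iterator.empty
  } else {
    callFunctionAndUpdateState(valueRowIter)
  }
}
```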

@@ -388,6 +397,7 @@ trait FlatMapGroupsWithStateExecBase
* @param eventTimeWatermarkForEviction event time watermark for state eviction
* @param initialState the user specified initial state
* @param hasInitialState indicates whether the initial state is provided or not
* @param skipEmittingInitialStateKeys whether to skip emitting initial state df keys
* @param child the physical plan for the underlying data
*/
case class FlatMapGroupsWithStateExec(
@@ -410,6 +420,7 @@ case class FlatMapGroupsWithStateExec(
eventTimeWatermarkForEviction: Option[Long],
initialState: SparkPlan,
hasInitialState: Boolean,
skipEmittingInitialStateKeys: Boolean,
child: SparkPlan)
extends FlatMapGroupsWithStateExecBase with BinaryExecNode with ObjectProducerExec {
import GroupStateImpl._
@@ -533,6 +544,7 @@ object FlatMapGroupsWithStateExec {
outputObjAttr: Attribute,
timeoutConf: GroupStateTimeout,
hasInitialState: Boolean,
skipEmittingInitialStateKeys: Boolean,
initialState: SparkPlan,
child: SparkPlan): SparkPlan = {
if (hasInitialState) {
@@ -541,27 +553,31 @@
case _ => false
}
val func = (keyRow: Any, values: Iterator[Any], states: Iterator[Any]) => {
// Check if there is only one state for every key.
var foundInitialStateForKey = false
val optionalStates = states.map { stateValue =>
if (foundInitialStateForKey) {
foundDuplicateInitialKeyException()
}
foundInitialStateForKey = true
stateValue
}.toArray

// Create group state object
val groupState = GroupStateImpl.createForStreaming(
optionalStates.headOption,
System.currentTimeMillis,
GroupStateImpl.NO_TIMESTAMP,
timeoutConf,
hasTimedOut = false,
watermarkPresent)

// Call user function with the state and values for this key
userFunc(keyRow, values, groupState)
if (skipEmittingInitialStateKeys && values.isEmpty) {
Iterator.empty
} else {
// Check if there is only one state for every key.
var foundInitialStateForKey = false
val optionalStates = states.map { stateValue =>
if (foundInitialStateForKey) {
foundDuplicateInitialKeyException()
}
foundInitialStateForKey = true
stateValue
}.toArray

// Create group state object
val groupState = GroupStateImpl.createForStreaming(
optionalStates.headOption,
System.currentTimeMillis,
GroupStateImpl.NO_TIMESTAMP,
timeoutConf,
hasTimedOut = false,
watermarkPresent)

// Call user function with the state and values for this key
userFunc(keyRow, values, groupState)
}
}
CoGroupExec(
func, keyDeserializer, valueDeserializer, initialStateDeserializer, groupingAttributes,
@@ -1177,6 +1177,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
Some(currentBatchTimestamp), Some(0), Some(currentBatchWatermark),
RDDScanExec(g, emptyRdd, "rdd"),
hasInitialState,
false,
RDDScanExec(g, emptyRdd, "rdd"))
}.get
}
@@ -351,6 +351,135 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends StateStoreMetricsTest
)
}

// If the keys in the initial state df differ from the keys in the input data, they
// will not be emitted as part of the result when skipEmittingInitialStateKeys is set to true
testWithAllStateVersions("flatMapGroupsWithState - initial state - " +
s"skipEmittingInitialStateKeys=true") {
withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> "true") {
val initialState = Seq(
("apple", 1L),
("orange", 2L),
("mango", 5L)).toDS().groupByKey(_._1).mapValues(_._2)

val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => {
val count = state.getOption.map( x => x).getOrElse(0L) + values.size
state.update(count)
Iterator.single((key, count))
}

val inputData = MemoryStream[String]
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc)
testStream(result, Update)(
AddData(inputData, "apple"),
AddData(inputData, "banana"),
CheckNewAnswer(("apple", 2), ("banana", 1)),
AddData(inputData, "orange"),
CheckNewAnswer(("orange", 3)),
StopStream
)
}
}

// If the keys in the initial state df differ from the keys in the input data, they
// will be emitted as part of the result when skipEmittingInitialStateKeys is set to false
testWithAllStateVersions("flatMapGroupsWithState - initial state - " +
s"skipEmittingInitialStateKeys=false") {
withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> "false") {
val initialState = Seq(
("apple", 1L),
("orange", 2L),
("mango", 5L)).toDS().groupByKey(_._1).mapValues(_._2)

val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => {
val count = state.getOption.map( x => x).getOrElse(0L) + values.size
state.update(count)
Iterator.single((key, count))
}

val inputData = MemoryStream[String]
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc)
testStream(result, Update)(
AddData(inputData, "apple"),
AddData(inputData, "banana"),
CheckNewAnswer(("apple", 2), ("banana", 1), ("orange", 2), ("mango", 5)),
AddData(inputData, "orange"),
CheckNewAnswer(("orange", 3)),
StopStream
)
}
}

// If the initial state and the first batch have the same keys, the result is the
// same irrespective of the value of skipEmittingInitialStateKeys
Seq(true, false).foreach { skipEmittingInitialStateKeys =>
testWithAllStateVersions("flatMapGroupsWithState - initial state and initial batch " +
s"have same keys and skipEmittingInitialStateKeys=$skipEmittingInitialStateKeys") {
withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key ->
skipEmittingInitialStateKeys.toString) {
val initialState = Seq(
("apple", 1L),
("orange", 2L)).toDS().groupByKey(_._1).mapValues(_._2)

val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => {
val count = state.getOption.map(x => x).getOrElse(0L) + values.size
state.update(count)
Iterator.single((key, count))
}

val inputData = MemoryStream[String]
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc)
testStream(result, Update)(
AddData(inputData, "apple"),
AddData(inputData, "apple"),
AddData(inputData, "orange"),
CheckNewAnswer(("apple", 3), ("orange", 3)),
AddData(inputData, "orange"),
CheckNewAnswer(("orange", 4)),
StopStream
)
}
}
}

Seq(true, false).foreach { skipEmittingInitialStateKeys =>
testWithAllStateVersions("flatMapGroupsWithState - batch query and " +
s"skipEmittingInitialStateKeys=$skipEmittingInitialStateKeys") {
withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key ->
skipEmittingInitialStateKeys.toString) {
val initialState = Seq(
("apple", 1L),
("orange", 2L)).toDS().groupByKey(_._1).mapValues(_._2)

val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => {
val count = state.getOption.map(x => x).getOrElse(0L) + values.size
state.update(count)
Iterator.single((key, count))
}

val inputData = Seq("orange", "mango")
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc)
val df = result.toDF()
if (skipEmittingInitialStateKeys) {
checkAnswer(df, Seq(("orange", 3), ("mango", 1)).toDF())
} else {
checkAnswer(df, Seq(("apple", 1), ("orange", 3), ("mango", 1)).toDF())
}
}
}
}

def testWithAllStateVersions(name: String)(func: => Unit): Unit = {
for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) {
test(s"$name - state format version $version") {
