[SPARK-50967][SS] Add option to skip emitting initial state keys within the FMGWS operator #49632

Open · wants to merge 3 commits into base: master
Changes from 1 commit
[SPARK-50967] Add option to skip emitting initial state keys within the FMGWS operator
anishshri-db committed Jan 24, 2025
commit 65db30323dd3e514e8a9cb76bff409a8303e461e
@@ -2292,6 +2292,15 @@ object SQLConf {
      .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
      .createWithDefault(2)

+  val FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS =
+    buildConf("spark.sql.streaming.flatMapGroupsWithState.skipEmittingInitialStateKeys")
+      .internal()
+      .doc("When true, the flatMapGroupsWithState operation in a streaming query will not emit " +
+        "results for the initial state keys of each group.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
  val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation")
    .doc("The default location for storing checkpoint data for streaming queries.")
    .version("2.0.0")
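
Since the conf is marked `.internal()`, it never appears in user-facing docs and is toggled programmatically. A minimal sketch of enabling it (assuming an active `SparkSession` named `spark`; not part of this diff):

```scala
// Sketch only: enable the new internal conf before starting the streaming
// query. The key string mirrors the buildConf(...) call above.
spark.conf.set(
  "spark.sql.streaming.flatMapGroupsWithState.skipEmittingInitialStateKeys",
  "true")
```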
@@ -736,11 +736,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
        func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
        timeout, hasInitialState, stateGroupAttr, sda, sDeser, initialState, child) =>
        val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
+        val skipEmittingInitialStateKeys =
+          conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS)
        val execPlan = FlatMapGroupsWithStateExec(
          func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, dataAttr, sda, outputAttr,
          None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs = None,
          eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None,
-          planLater(initialState), hasInitialState, planLater(child)
+          planLater(initialState), hasInitialState, skipEmittingInitialStateKeys, planLater(child)
        )
        execPlan :: Nil
      case _ =>

@@ -828,7 +830,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
        val execPlan = python.FlatMapGroupsInPandasWithStateExec(
          func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout,
          batchTimestampMs = None, eventTimeWatermarkForLateEvents = None,
-          eventTimeWatermarkForEviction = None, planLater(child)
+          eventTimeWatermarkForEviction = None,
+          skipEmittingInitialStateKeys = false,
+          planLater(child)
        )
        execPlan :: Nil
      case _ =>

@@ -953,10 +957,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
        f, keyDeserializer, valueDeserializer, grouping, data, output, stateEncoder, outputMode,
        isFlatMapGroupsWithState, timeout, hasInitialState, initialStateGroupAttrs,
        initialStateDataAttrs, initialStateDeserializer, initialState, child) =>
+        val skipEmittingInitialStateKeys =
+          conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS)
        FlatMapGroupsWithStateExec.generateSparkPlanForBatchQueries(
          f, keyDeserializer, valueDeserializer, initialStateDeserializer, grouping,
          initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
-          hasInitialState, planLater(initialState), planLater(child)
+          hasInitialState, skipEmittingInitialStateKeys, planLater(initialState), planLater(child)
        ) :: Nil
      case logical.TransformWithState(keyDeserializer, valueDeserializer, groupingAttributes,
        dataAttributes, statefulProcessor, timeMode, outputMode, keyEncoder,
@@ -50,6 +50,7 @@ import org.apache.spark.util.CompletionIterator
 * @param batchTimestampMs processing timestamp of the current batch.
 * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events
 * @param eventTimeWatermarkForEviction event time watermark for state eviction
+ * @param skipEmittingInitialStateKeys whether to skip emitting the initial state keys
 * @param child logical plan of the underlying data
 */
case class FlatMapGroupsInPandasWithStateExec(

@@ -64,6 +65,7 @@ case class FlatMapGroupsInPandasWithStateExec(
    batchTimestampMs: Option[Long],
    eventTimeWatermarkForLateEvents: Option[Long],
    eventTimeWatermarkForEviction: Option[Long],
+    skipEmittingInitialStateKeys: Boolean,
    child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {

  // TODO(SPARK-40444): Add the support of initial state.

@@ -137,7 +139,8 @@ case class FlatMapGroupsInPandasWithStateExec(

  override def processNewDataWithInitialState(
      childDataIter: Iterator[InternalRow],
-      initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = {
+      initStateIter: Iterator[InternalRow],
+      skipEmittingInitialStateKeys: Boolean): Iterator[InternalRow] = {
    throw SparkUnsupportedOperationException()
  }
@@ -22,7 +22,7 @@ import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration

-import org.apache.spark.{SparkException, SparkThrowable}
+import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

@@ -52,6 +52,7 @@ trait FlatMapGroupsWithStateExecBase
  protected val initialStateDataAttrs: Seq[Attribute]
  protected val initialState: SparkPlan
  protected val hasInitialState: Boolean
+  protected val skipEmittingInitialStateKeys: Boolean

  val stateInfo: Option[StatefulOperatorStateInfo]
  protected val stateEncoder: ExpressionEncoder[Any]

@@ -145,7 +146,8 @@ trait FlatMapGroupsWithStateExecBase

        val processedOutputIterator = initialStateIterOption match {
          case Some(initStateIter) if initStateIter.hasNext =>
-            processor.processNewDataWithInitialState(filteredIter, initStateIter)
+            processor.processNewDataWithInitialState(filteredIter, initStateIter,
+              skipEmittingInitialStateKeys)
          case _ => processor.processNewData(filteredIter)
        }

@@ -301,7 +303,8 @@ trait FlatMapGroupsWithStateExecBase
     */
    def processNewDataWithInitialState(
        childDataIter: Iterator[InternalRow],
-        initStateIter: Iterator[InternalRow]
+        initStateIter: Iterator[InternalRow],
+        skipEmittingInitialStateKeys: Boolean
      ): Iterator[InternalRow] = {

      if (!childDataIter.hasNext && !initStateIter.hasNext) return Iterator.empty

@@ -312,10 +315,10 @@ trait FlatMapGroupsWithStateExecBase
      val groupedInitialStateIter =
        GroupedIterator(initStateIter, initialStateGroupAttrs, initialState.output)

-      // Create a CoGroupedIterator that will group the two iterators together for every key group.
-      new CoGroupedIterator(
-        groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap {
-        case (keyRow, valueRowIter, initialStateRowIter) =>
+      if (skipEmittingInitialStateKeys) {
+        // If we are skipping emitting initial state keys, we can just process the initial state
+        // rows to populate the state store and then process the child data rows.
+        groupedInitialStateIter.foreach { case (keyRow, initialStateRowIter) =>
          val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
          var foundInitialStateForKey = false
          initialStateRowIter.foreach { initialStateRow =>

@@ -326,14 +329,40 @@ trait FlatMapGroupsWithStateExecBase
            val initStateObj = getStateObj.get(initialStateRow)
            stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP)
          }
-          // We apply the values for the key after applying the initial state.
+        }
+
+        groupedChildDataIter.flatMap { case (keyRow, valueRowIter) =>
+          val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
          callFunctionAndUpdateState(
            stateManager.getState(store, keyUnsafeRow),
-            valueRowIter,
-            hasTimedOut = false
-          )
+            valueRowIter,
+            hasTimedOut = false)
+        }
+      } else {
+        // Create a CoGroupedIterator that will group the two iterators together for every
+        // key group.
+        new CoGroupedIterator(
+          groupedChildDataIter, groupedInitialStateIter, groupingAttributes).flatMap {
+          case (keyRow, valueRowIter, initialStateRowIter) =>
+            val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
+            var foundInitialStateForKey = false
+            initialStateRowIter.foreach { initialStateRow =>
+              if (foundInitialStateForKey) {
+                FlatMapGroupsWithStateExec.foundDuplicateInitialKeyException()
+              }
+              foundInitialStateForKey = true
+              val initStateObj = getStateObj.get(initialStateRow)
+              stateManager.putState(store, keyUnsafeRow, initStateObj, NO_TIMESTAMP)
+            }
+            // We apply the values for the key after applying the initial state.
+            callFunctionAndUpdateState(
+              stateManager.getState(store, keyUnsafeRow),
+              valueRowIter,
+              hasTimedOut = false)
+        }
+      }
    }

@@ -388,6 +417,7 @@ trait FlatMapGroupsWithStateExecBase
 * @param eventTimeWatermarkForEviction event time watermark for state eviction
 * @param initialState the user specified initial state
 * @param hasInitialState indicates whether the initial state is provided or not
+ * @param skipEmittingInitialStateKeys whether to skip emitting the initial state keys
 * @param child the physical plan for the underlying data
 */
case class FlatMapGroupsWithStateExec(

@@ -410,6 +440,7 @@ case class FlatMapGroupsWithStateExec(
    eventTimeWatermarkForEviction: Option[Long],
    initialState: SparkPlan,
    hasInitialState: Boolean,
+    skipEmittingInitialStateKeys: Boolean,
    child: SparkPlan)
  extends FlatMapGroupsWithStateExecBase with BinaryExecNode with ObjectProducerExec {
  import GroupStateImpl._

@@ -533,9 +564,16 @@ object FlatMapGroupsWithStateExec {
      outputObjAttr: Attribute,
      timeoutConf: GroupStateTimeout,
      hasInitialState: Boolean,
+      skipEmittingInitialStateKeys: Boolean,
      initialState: SparkPlan,
      child: SparkPlan): SparkPlan = {
    if (hasInitialState) {
+      // We won't support skipping emitting initial state keys for batch queries,
+      // since the underlying CoGroupExec does not support it.
+      if (skipEmittingInitialStateKeys) {
+        throw SparkUnsupportedOperationException()
+      }
+
      val watermarkPresent = child.output.exists {
        case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
        case _ => false
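
For intuition, the two branches above differ only in which keys the user function is invoked for. A self-contained toy model of that choice (plain Scala with a hypothetical `run` helper and toy types, not the operator's real API): with the flag on, the initial state only seeds the store and the function runs for keys present in the new batch; with the flag off, it runs for the co-grouped union of batch and initial-state keys.

```scala
// Toy model of processNewDataWithInitialState: `initial` plays the role of
// the initial state, `batch` the new data grouped by key. Only the branching
// logic mirrors the diff above; everything else is simplified.
def run(
    initial: Map[String, Long],
    batch: Map[String, Seq[String]],
    skipEmittingInitialStateKeys: Boolean): Seq[(String, Long)] = {
  val store = scala.collection.mutable.Map.empty[String, Long]
  initial.foreach { case (k, v) => store(k) = v }  // seed state for every key
  val keysToInvoke =
    if (skipEmittingInitialStateKeys) batch.keySet // keys with new data only
    else batch.keySet ++ initial.keySet            // co-grouped union of keys
  keysToInvoke.toSeq.sorted.map { k =>
    val count = store.getOrElse(k, 0L) + batch.getOrElse(k, Nil).size
    store(k) = count
    (k, count)                                     // what the user function emits
  }
}

// run(Map("orange" -> 2L), Map("apple" -> Seq("apple")), true)
//   == Seq(("apple", 1))
// run(Map("orange" -> 2L), Map("apple" -> Seq("apple")), false)
//   == Seq(("apple", 1), ("orange", 2))
```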
@@ -1177,6 +1177,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
          Some(currentBatchTimestamp), Some(0), Some(currentBatchWatermark),
          RDDScanExec(g, emptyRdd, "rdd"),
          hasInitialState,
+          false,
          RDDScanExec(g, emptyRdd, "rdd"))
      }.get
  }
@@ -351,6 +351,66 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends StateStoreMetricsTest
    )
  }

+  testWithAllStateVersions("flatMapGroupsWithState - initial state - " +
+    "skipEmittingInitialStateKeys=true") {
+    withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> "true") {
+      val initialState = Seq(
+        ("apple", 1L),
+        ("orange", 2L),
+        ("mango", 5L)).toDS().groupByKey(_._1).mapValues(_._2)
+
+      val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => {
+        val count = state.getOption.getOrElse(0L) + values.size
+        state.update(count)
+        Iterator.single((key, count))
+      }
+
+      val inputData = MemoryStream[String]
+      val result =
+        inputData.toDS()
+          .groupByKey(x => x)
+          .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc)
+      testStream(result, Update)(
+        AddData(inputData, "apple"),
+        AddData(inputData, "banana"),
+        CheckNewAnswer(("apple", 2), ("banana", 1)),
+        AddData(inputData, "orange"),
+        CheckNewAnswer(("orange", 3)),
+        StopStream
+      )
+    }
+  }
+
+  testWithAllStateVersions("flatMapGroupsWithState - initial state - " +
+    "skipEmittingInitialStateKeys=false") {
+    withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_SKIP_EMITTING_INITIAL_STATE_KEYS.key -> "false") {
+      val initialState = Seq(
+        ("apple", 1L),
+        ("orange", 2L),
+        ("mango", 5L)).toDS().groupByKey(_._1).mapValues(_._2)
+
+      val fruitCountFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => {
+        val count = state.getOption.getOrElse(0L) + values.size
+        state.update(count)
+        Iterator.single((key, count))
+      }
+
+      val inputData = MemoryStream[String]
+      val result =
+        inputData.toDS()
+          .groupByKey(x => x)
+          .flatMapGroupsWithState(Update, NoTimeout(), initialState)(fruitCountFunc)
+      testStream(result, Update)(
+        AddData(inputData, "apple"),
+        AddData(inputData, "banana"),
+        CheckNewAnswer(("apple", 2), ("banana", 1), ("orange", 2), ("mango", 5)),
+        AddData(inputData, "orange"),
+        CheckNewAnswer(("orange", 3)),
+        StopStream
+      )
+    }
+  }
+
  def testWithAllStateVersions(name: String)(func: => Unit): Unit = {
    for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) {
      test(s"$name - state format version $version") {
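
The two tests above pin down the user-visible contract. The same query shape outside the test harness might look as follows (a sketch assuming a streaming `Dataset[String]` named `fruits` plus the usual `GroupState`/`OutputMode` imports and session implicits in scope):

```scala
// Sketch, not from the diff: the query the tests drive via MemoryStream.
val initialState = Seq(("apple", 1L), ("orange", 2L), ("mango", 5L))
  .toDS().groupByKey(_._1).mapValues(_._2)

val counts = fruits
  .groupByKey(identity)
  .flatMapGroupsWithState(Update, NoTimeout(), initialState) {
    (key: String, values: Iterator[String], state: GroupState[Long]) =>
      val count = state.getOption.getOrElse(0L) + values.size
      state.update(count)
      Iterator.single((key, count))
  }
// First batch ["apple", "banana"]: with skipEmittingInitialStateKeys=true the
// query emits ("apple", 2) and ("banana", 1) only; with the default (false) it
// also emits the untouched initial keys ("orange", 2) and ("mango", 5).
```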