[SPARK-40925][SQL][SS] Fix stateful operator late record filtering
### What changes were proposed in this pull request?

This PR fixes the input late record filtering done by stateful operators so that stateful operators can be chained. Currently, stateful operators are initialized with the current microbatch watermark and use that same watermark value both for input late record filtering and for state eviction (e.g. producing aggregations). The state evicted (or the aggregates produced) when the watermark advances is behind the watermark and thus effectively late - if a downstream stateful operator consumes the output of the previous one, those records are filtered out as late input.

This PR provides two watermark values to the stateful operators - one from the previous microbatch, used for late record filtering, and one from the current microbatch (as in the existing code), used for state eviction. This fixes the broken late record filtering described above.
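
To illustrate the mechanics, here is a minimal, self-contained sketch of the two-watermark scheme (the names and structure are illustrative only, not the actual Spark internals):

```scala
// Sketch of the two-watermark scheme; identifiers are illustrative, not Spark internals.
object TwoWatermarkSketch {
  case class Event(time: Long, value: Long)

  // forLateEvents: watermark of the previous microbatch, used for late record filtering.
  // forEviction: watermark of the current microbatch, used for state eviction.
  case class Watermarks(forLateEvents: Long, forEviction: Long)

  def processBatch(
      input: Seq[Event],
      wm: Watermarks,
      state: scala.collection.mutable.Map[Long, Long]): Seq[(Long, Long)] = {
    // 1. Filter late input against the previous-batch watermark, so records emitted
    //    by an upstream stateful operator right at the old watermark are still accepted.
    val onTime = input.filter(_.time >= wm.forLateEvents)
    onTime.foreach(e => state(e.time) = state.getOrElse(e.time, 0L) + e.value)
    // 2. Evict (finalize and emit) state against the current-batch watermark, as before.
    val finalized = state.keys.filter(_ < wm.forEviction).toSeq.sorted
    finalized.map { t => (t, state.remove(t).get) }
  }
}
```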

Note that this PR still does not solve the issue of the time-interval stream join producing records that are delayed relative to the watermark. Therefore, a time-interval streaming join followed by stateful operators is still not supported. That will be fixed in a follow-up PR (and a SPIP) that effectively replaces the single global watermark with, conceptually, per-operator watermarks.

Also, the stateful operator chains unblocked by this PR (e.g. a chain of window aggregations) are still blocked by the unsupported operations checker. The new test suite for these scenarios - MultiStatefulOperatorsSuite - has to explicitly disable the unsupported operations check. This, again, will be fixed in a follow-up PR.
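
For reference, the unsupported operations check is gated by an internal flag, so a test (or a user experimenting with such chains before the follow-up lands) can turn it off roughly as follows - a sketch only; the exact wiring in the suite may differ:

```scala
// Sketch: disable the streaming unsupported-operations check so that a chain of
// stateful operators can be planned. The exact mechanism in the test suite may differ.
spark.conf.set("spark.sql.streaming.unsupportedOperationCheck", "false")
```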

### Why are the changes needed?

The PR allows Spark Structured Streaming to support chaining of stateful operators, e.g. chaining of time window aggregations, which is a meaningful streaming scenario.

### Does this PR introduce _any_ user-facing change?

With this PR, chains of stateful operators will be supported in Spark Structured Streaming.
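
For example, a chained time window aggregation along the following lines becomes expressible (a sketch; the source, column names and window sizes are illustrative):

```scala
// Sketch of a chained time window aggregation enabled by this change.
// Source, column names and window sizes are illustrative.
import org.apache.spark.sql.functions._

val events = spark.readStream.format("rate").load()
  .withWatermark("timestamp", "10 seconds")

// First stateful operator: counts per 5-second window.
val perWindow = events
  .groupBy(window(col("timestamp"), "5 seconds"))
  .count()

// Second stateful operator: re-aggregate the windowed output; the window struct
// column produced above is used as the event-time column downstream.
val perLargerWindow = perWindow
  .groupBy(window(col("window"), "10 seconds"))
  .sum("count")
```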

### How was this patch tested?

Added a new test suite - MultiStatefulOperatorsSuite.

Closes #38405 from alex-balikov/multiple_stateful-ops-base.

Lead-authored-by: Alex Balikov <91913242+alex-balikov@users.noreply.github.com>
Co-authored-by: Alex Balikov <alex.balikov@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
2 people authored and HeartSaVioR committed Oct 31, 2022
1 parent 41ccae0 commit 406d0e2
Showing 19 changed files with 661 additions and 81 deletions.
@@ -68,11 +68,29 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Expression) extend
with Unevaluable
with NonSQLExpression {

private def inputTypeOnTimeColumn: AbstractDataType = {
TypeCollection(
AnyTimestampType,
// Below two types cover both time window & session window, since they produce the same type
// of output as window column.
new StructType()
.add(StructField("start", TimestampType))
.add(StructField("end", TimestampType)),
new StructType()
.add(StructField("start", TimestampNTZType))
.add(StructField("end", TimestampNTZType))
)
}

// NOTE: if the window column is given as a time column, we resolve it to the point of time,
// which resolves to either TimestampType or TimestampNTZType. That means, timeColumn may not
// be "resolved", so it is safe to not rely on the data type of timeColumn directly.

override def children: Seq[Expression] = Seq(timeColumn, gapDuration)
override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType, AnyDataType)
override def inputTypes: Seq[AbstractDataType] = Seq(inputTypeOnTimeColumn, AnyDataType)
override def dataType: DataType = new StructType()
.add(StructField("start", timeColumn.dataType))
.add(StructField("end", timeColumn.dataType))
.add(StructField("start", children.head.dataType))
.add(StructField("end", children.head.dataType))

// This expression is replaced in the analyzer.
override lazy val resolved = false
@@ -96,8 +96,26 @@ case class TimeWindow(
this(timeColumn, windowDuration, windowDuration)
}

private def inputTypeOnTimeColumn: AbstractDataType = {
TypeCollection(
AnyTimestampType,
// Below two types cover both time window & session window, since they produce the same type
// of output as window column.
new StructType()
.add(StructField("start", TimestampType))
.add(StructField("end", TimestampType)),
new StructType()
.add(StructField("start", TimestampNTZType))
.add(StructField("end", TimestampNTZType))
)
}

// NOTE: if the window column is given as a time column, we resolve it to the point of time,
// which resolves to either TimestampType or TimestampNTZType. That means, timeColumn may not
// be "resolved", so it is safe to not rely on the data type of timeColumn directly.

override def child: Expression = timeColumn
override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType)
override def inputTypes: Seq[AbstractDataType] = Seq(inputTypeOnTimeColumn)
override def dataType: DataType = new StructType()
.add(StructField("start", child.dataType))
.add(StructField("end", child.dataType))
@@ -1941,6 +1941,22 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val STATEFUL_OPERATOR_ALLOW_MULTIPLE =
buildConf("spark.sql.streaming.statefulOperator.allowMultiple")
.internal()
.doc("When true, multiple stateful operators are allowed to be present in a streaming " +
"pipeline. The support for multiple stateful operators introduces a minor (semantically " +
"correct) change in respect to late record filtering - late records are detected and " +
"filtered in respect to the watermark from the previous microbatch instead of the " +
"current one. This is a behavior change for Spark streaming pipelines and we allow " +
"users to revert to the previous behavior of late record filtering (late records are " +
"detected and filtered by comparing with the current microbatch watermark) by setting " +
"the flag value to false. In this mode, only a single stateful operator will be allowed " +
"in a streaming pipeline.")
.version("3.4.0")
.booleanConf
.createWithDefault(true)

val STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION =
buildConf("spark.sql.streaming.statefulOperator.useStrictDistribution")
.internal()
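
As a usage sketch for the flag defined above: reverting to the previous late record filtering behavior (comparing against the current microbatch watermark, with only a single stateful operator allowed per query) amounts to setting

    spark.conf.set("spark.sql.streaming.statefulOperator.allowMultiple", "false")
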
@@ -229,7 +229,7 @@ class QueryExecution(
// output mode does not matter since there is no `Sink`.
new IncrementalExecution(
sparkSession, logical, OutputMode.Append(), "<unknown>",
UUID.randomUUID, UUID.randomUUID, 0, OffsetSeqMetadata(0, 0))
UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0))
} else {
this
}
@@ -678,7 +678,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val execPlan = FlatMapGroupsWithStateExec(
func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, dataAttr, sda, outputAttr,
None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs = None,
eventTimeWatermark = None, planLater(initialState), hasInitialState, planLater(child)
eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None,
planLater(initialState), hasInitialState, planLater(child)
)
execPlan :: Nil
case _ =>
@@ -697,7 +698,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
val execPlan = python.FlatMapGroupsInPandasWithStateExec(
func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout,
batchTimestampMs = None, eventTimeWatermark = None, planLater(child)
batchTimestampMs = None, eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None, planLater(child)
)
execPlan :: Nil
case _ =>
@@ -369,7 +369,8 @@ object AggUtils {
groupingAttributes,
stateInfo = None,
outputMode = None,
eventTimeWatermark = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
stateFormatVersion = stateFormatVersion,
partialMerged2)

@@ -472,7 +473,8 @@

// shuffle & sort happens here: most of details are also handled in this physical plan
val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes,
sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None,
sessionExpression.toAttribute, stateInfo = None,
eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None,
stateFormatVersion, partialMerged1)

val mergedSessions = {
@@ -501,7 +503,8 @@ sessionExpression.toAttribute,
sessionExpression.toAttribute,
stateInfo = None,
outputMode = None,
eventTimeWatermark = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
stateFormatVersion, mergedSessions)

val finalAndCompleteAggregate: SparkPlan = {
@@ -48,7 +48,8 @@ import org.apache.spark.util.CompletionIterator
* @param outputMode the output mode of `functionExpr`
* @param timeoutConf used to timeout groups that have not received data in a while
* @param batchTimestampMs processing timestamp of the current batch.
* @param eventTimeWatermark event time watermark for the current batch
* @param eventTimeWatermarkForLateEvents event time watermark for filtering late events
* @param eventTimeWatermarkForEviction event time watermark for state eviction
* @param child logical plan of the underlying data
*/
case class FlatMapGroupsInPandasWithStateExec(
@@ -61,9 +62,9 @@ case class FlatMapGroupsInPandasWithStateExec(
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
eventTimeWatermark: Option[Long],
child: SparkPlan)
extends UnaryExecNode with PythonSQLMetrics with FlatMapGroupsWithStateExecBase {
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {

// TODO(SPARK-40444): Add the support of initial state.
override protected val initialStateDeserializer: Expression = null
@@ -132,7 +133,7 @@ case class FlatMapGroupsInPandasWithStateExec(
if (isTimeoutEnabled) {
val timeoutThreshold = timeoutConf match {
case ProcessingTimeTimeout => batchTimestampMs.get
case EventTimeTimeout => eventTimeWatermark.get
case EventTimeTimeout => eventTimeWatermarkForEviction.get
case _ =>
throw new IllegalStateException(
s"Cannot filter timed out keys for $timeoutConf")
@@ -176,7 +177,7 @@ case class FlatMapGroupsInPandasWithStateExec(
val groupedState = GroupStateImpl.createForStreaming(
Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r },
batchTimestampMs.getOrElse(NO_TIMESTAMP),
eventTimeWatermark.getOrElse(NO_TIMESTAMP),
eventTimeWatermarkForEviction.getOrElse(NO_TIMESTAMP),
timeoutConf,
hasTimedOut = hasTimedOut,
watermarkPresent).asInstanceOf[GroupStateImpl[Row]]
@@ -54,8 +54,8 @@ trait FlatMapGroupsWithStateExecBase
protected val outputMode: OutputMode
protected val timeoutConf: GroupStateTimeout
protected val batchTimestampMs: Option[Long]
val eventTimeWatermark: Option[Long]

val eventTimeWatermarkForLateEvents: Option[Long]
val eventTimeWatermarkForEviction: Option[Long]
protected val isTimeoutEnabled: Boolean = timeoutConf != NoTimeout
protected val watermarkPresent: Boolean = child.output.exists {
case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
@@ -96,7 +96,8 @@ trait FlatMapGroupsWithStateExecBase
true // Always run batches to process timeouts
case EventTimeTimeout =>
// Process another non-data batch only if the watermark has changed in this executed plan
eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get
eventTimeWatermarkForEviction.isDefined &&
newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get
case _ =>
false
}
@@ -125,7 +126,7 @@ trait FlatMapGroupsWithStateExecBase
var timeoutProcessingStartTimeNs = currentTimeNs

// If timeout is based on event time, then filter late data based on watermark
val filteredIter = watermarkPredicateForData match {
val filteredIter = watermarkPredicateForDataForLateEvents match {
case Some(predicate) if timeoutConf == EventTimeTimeout =>
applyRemovingRowsOlderThanWatermark(iter, predicate)
case _ =>
@@ -189,8 +190,12 @@ trait FlatMapGroupsWithStateExecBase
case ProcessingTimeTimeout =>
require(batchTimestampMs.nonEmpty)
case EventTimeTimeout =>
require(eventTimeWatermark.nonEmpty) // watermark value has been populated
require(watermarkExpression.nonEmpty) // input schema has watermark attribute
// watermark value has been populated
require(eventTimeWatermarkForLateEvents.nonEmpty)
require(eventTimeWatermarkForEviction.nonEmpty)
// input schema has watermark attribute
require(watermarkExpressionForLateEvents.nonEmpty)
require(watermarkExpressionForEviction.nonEmpty)
case _ =>
}

@@ -310,7 +315,7 @@ trait FlatMapGroupsWithStateExecBase
if (isTimeoutEnabled) {
val timeoutThreshold = timeoutConf match {
case ProcessingTimeTimeout => batchTimestampMs.get
case EventTimeTimeout => eventTimeWatermark.get
case EventTimeTimeout => eventTimeWatermarkForEviction.get
case _ =>
throw new IllegalStateException(
s"Cannot filter timed out keys for $timeoutConf")
@@ -354,7 +359,8 @@ trait FlatMapGroupsWithStateExecBase
* @param outputMode the output mode of `func`
* @param timeoutConf used to timeout groups that have not received data in a while
* @param batchTimestampMs processing timestamp of the current batch.
* @param eventTimeWatermark event time watermark for the current batch
* @param eventTimeWatermarkForLateEvents event time watermark for filtering late events
* @param eventTimeWatermarkForEviction event time watermark for state eviction
* @param initialState the user specified initial state
* @param hasInitialState indicates whether the initial state is provided or not
* @param child the physical plan for the underlying data
@@ -375,7 +381,8 @@ case class FlatMapGroupsWithStateExec(
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
eventTimeWatermark: Option[Long],
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
initialState: SparkPlan,
hasInitialState: Boolean,
child: SparkPlan)
@@ -410,7 +417,7 @@ case class FlatMapGroupsWithStateExec(
val groupState = GroupStateImpl.createForStreaming(
Option(stateData.stateObj),
batchTimestampMs.getOrElse(NO_TIMESTAMP),
eventTimeWatermark.getOrElse(NO_TIMESTAMP),
eventTimeWatermarkForEviction.getOrElse(NO_TIMESTAMP),
timeoutConf,
hasTimedOut,
watermarkPresent)
@@ -48,6 +48,7 @@ class IncrementalExecution(
val queryId: UUID,
val runId: UUID,
val currentBatchId: Long,
val prevOffsetSeqMetadata: Option[OffsetSeqMetadata],
val offsetSeqMetadata: OffsetSeqMetadata)
extends QueryExecution(sparkSession, logicalPlan) with Logging {

@@ -112,6 +113,17 @@ class IncrementalExecution(
numStateStores)
}

// Watermarks to use for late record filtering and state eviction in stateful operators.
// Using the previous watermark for late record filtering is a Spark behavior change so we allow
// this to be disabled.
val eventTimeWatermarkForEviction = offsetSeqMetadata.batchWatermarkMs
val eventTimeWatermarkForLateEvents =
if (sparkSession.conf.get(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)) {
prevOffsetSeqMetadata.getOrElse(offsetSeqMetadata).batchWatermarkMs
} else {
eventTimeWatermarkForEviction
}

/** Locates save/restore pairs surrounding aggregation. */
val state = new Rule[SparkPlan] {

@@ -158,15 +170,16 @@
case a: UpdatingSessionsExec if a.isStreaming =>
a.copy(numShufflePartitions = Some(numStateStores))

case StateStoreSaveExec(keys, None, None, None, stateFormatVersion,
case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion,
UnaryExecNode(agg,
StateStoreRestoreExec(_, None, _, child))) =>
val aggStateInfo = nextStatefulOperationStateInfo
StateStoreSaveExec(
keys,
Some(aggStateInfo),
Some(outputMode),
Some(offsetSeqMetadata.batchWatermarkMs),
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction),
stateFormatVersion,
agg.withNewChildren(
StateStoreRestoreExec(
@@ -175,58 +188,65 @@
stateFormatVersion,
child) :: Nil))

case SessionWindowStateStoreSaveExec(keys, session, None, None, None, stateFormatVersion,
case SessionWindowStateStoreSaveExec(keys, session, None, None, None, None,
stateFormatVersion,
UnaryExecNode(agg,
SessionWindowStateStoreRestoreExec(_, _, None, None, _, child))) =>
SessionWindowStateStoreRestoreExec(_, _, None, None, None, _, child))) =>
val aggStateInfo = nextStatefulOperationStateInfo
SessionWindowStateStoreSaveExec(
keys,
session,
Some(aggStateInfo),
Some(outputMode),
Some(offsetSeqMetadata.batchWatermarkMs),
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction),
stateFormatVersion,
agg.withNewChildren(
SessionWindowStateStoreRestoreExec(
keys,
session,
Some(aggStateInfo),
Some(offsetSeqMetadata.batchWatermarkMs),
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction),
stateFormatVersion,
child) :: Nil))

case StreamingDeduplicateExec(keys, child, None, None) =>
case StreamingDeduplicateExec(keys, child, None, None, None) =>
StreamingDeduplicateExec(
keys,
child,
Some(nextStatefulOperationStateInfo),
Some(offsetSeqMetadata.batchWatermarkMs))
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction))

case m: FlatMapGroupsWithStateExec =>
// We set this to true only for the first batch of the streaming query.
val hasInitialState = (currentBatchId == 0L && m.hasInitialState)
m.copy(
stateInfo = Some(nextStatefulOperationStateInfo),
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs),
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction),
hasInitialState = hasInitialState
)

case m: FlatMapGroupsInPandasWithStateExec =>
m.copy(
stateInfo = Some(nextStatefulOperationStateInfo),
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction)
)

case j: StreamingSymmetricHashJoinExec =>
j.copy(
stateInfo = Some(nextStatefulOperationStateInfo),
eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs),
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction),
stateWatermarkPredicates =
StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates(
j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full,
Some(offsetSeqMetadata.batchWatermarkMs)))
Some(eventTimeWatermarkForEviction)))

case l: StreamingGlobalLimitExec =>
l.copy(
@@ -695,6 +695,7 @@ class MicroBatchExecution(
id,
runId,
currentBatchId,
offsetLog.offsetSeqMetadataForBatchId(currentBatchId - 1),
offsetSeqMetadata)
lastExecution.executedPlan // Force the lazy generation of execution plan
}