[SPARK-40925][SQL][SS] Fix stateful operator late record filtering #38405

Status: Closed

Changes from 20 commits (of 22 total):

- acc76dc x (alex-balikov, Oct 13, 2022)
- 8831691 Merge remote-tracking branch 'upstream/master' into multiple_stateful… (alex-balikov, Oct 13, 2022)
- 59002a2 f (alex-balikov, Oct 14, 2022)
- 3f1c322 Merge remote-tracking branch 'upstream/master' into multiple_stateful… (alex-balikov, Oct 14, 2022)
- 179f422 Merge remote-tracking branch 'upstream/master' into multiple_stateful… (alex-balikov, Oct 17, 2022)
- 6262312 build fixes (alex-balikov, Oct 17, 2022)
- 7c1f066 Merge remote-tracking branch 'upstream/master' into multiple_stateful… (alex-balikov, Oct 26, 2022)
- d031e32 revert changes (alex-balikov, Oct 26, 2022)
- c789d19 Merge remote-tracking branch 'upstream/master' into multiple_stateful… (alex-balikov, Oct 26, 2022)
- 4e12828 fixes (alex-balikov, Oct 26, 2022)
- 36b6826 fixes (alex-balikov, Oct 26, 2022)
- 764be57 Update sql/core/src/main/scala/org/apache/spark/sql/execution/streami… (alex-balikov, Oct 27, 2022)
- 5947fc9 Merge remote-tracking branch 'upstream/master' into multiple_stateful… (alex-balikov, Oct 27, 2022)
- bee6185 code review fixes (alex-balikov, Oct 27, 2022)
- 0e96fae Merge branch 'multiple_stateful-ops-base' of https://github.com/alex-… (alex-balikov, Oct 27, 2022)
- 565f8af code review fixes (alex-balikov, Oct 27, 2022)
- 24f1b61 fixes (alex-balikov, Oct 27, 2022)
- 9515f10 fixes (alex-balikov, Oct 27, 2022)
- 03cbdd6 fixes (alex-balikov, Oct 27, 2022)
- 572263d fixes (alex-balikov, Oct 27, 2022)
- 68b14e3 Merge remote-tracking branch 'upstream/master' into multiple_stateful… (alex-balikov, Oct 28, 2022)
- 6a023ae fixes (alex-balikov, Oct 28, 2022)
@@ -68,11 +68,29 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Expression) extends
with Unevaluable
with NonSQLExpression {

private def inputTypeOnTimeColumn: AbstractDataType = {
HeartSaVioR (Contributor) commented on Oct 27, 2022:

(Oh I missed that the change from rule side has already merged as a part of introduction of window_time.)

alex-balikov (Contributor, Author) replied:

Sorry, are you asking for anything actionable?

HeartSaVioR (Contributor) replied:

Nope. Just for future visibility.

TypeCollection(
AnyTimestampType,
// Below two types cover both time window & session window, since they produce the same type
// of output as window column.
new StructType()
.add(StructField("start", TimestampType))
.add(StructField("end", TimestampType)),
new StructType()
.add(StructField("start", TimestampNTZType))
.add(StructField("end", TimestampNTZType))
)
}

// NOTE: if the window column is given as a time column, we resolve it to the point of time,
// which resolves to either TimestampType or TimestampNTZType. That means, timeColumn may not
// be "resolved", so it is safe to not rely on the data type of timeColumn directly.

override def children: Seq[Expression] = Seq(timeColumn, gapDuration)
override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType, AnyDataType)
override def inputTypes: Seq[AbstractDataType] = Seq(inputTypeOnTimeColumn, AnyDataType)
override def dataType: DataType = new StructType()
.add(StructField("start", timeColumn.dataType))
.add(StructField("end", timeColumn.dataType))
.add(StructField("start", children.head.dataType))
.add(StructField("end", children.head.dataType))

// This expression is replaced in the analyzer.
override lazy val resolved = false
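
To see what the widened input type buys, here is a minimal sketch (hypothetical names; assumes a SparkSession `spark` and a streaming Dataset `events` with an `eventTime` timestamp column). The struct<start, end> produced by one window aggregation can now be passed where a time column is expected, which is exactly the shape chained stateful operators produce. Whether a given chain works end to end still depends on the analyzer rules mentioned in the review thread above.

```scala
import org.apache.spark.sql.functions.{session_window, window}
import spark.implicits._

// First stateful operator: tumbling-window counts.
val tumbled = events
  .withWatermark("eventTime", "10 minutes")
  .groupBy(window($"eventTime", "5 minutes"))
  .count()

// Second stateful operator: session windows over the window struct column.
// Previously this failed type checking because only AnyTimestampType was
// accepted; the TypeCollection above also admits the struct<start, end>.
val sessions = tumbled
  .groupBy(session_window($"window", "5 minutes"))
  .sum("count")
```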
@@ -96,8 +96,26 @@ case class TimeWindow(
this(timeColumn, windowDuration, windowDuration)
}

private def inputTypeOnTimeColumn: AbstractDataType = {
TypeCollection(
AnyTimestampType,
// Below two types cover both time window & session window, since they produce the same type
// of output as window column.
new StructType()
.add(StructField("start", TimestampType))
.add(StructField("end", TimestampType)),
new StructType()
.add(StructField("start", TimestampNTZType))
.add(StructField("end", TimestampNTZType))
)
}

// NOTE: if the window column is given as a time column, we resolve it to the point of time,
// which resolves to either TimestampType or TimestampNTZType. That means, timeColumn may not
// be "resolved", so it is safe to not rely on the data type of timeColumn directly.

override def child: Expression = timeColumn
override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType)
override def inputTypes: Seq[AbstractDataType] = Seq(inputTypeOnTimeColumn)
override def dataType: DataType = new StructType()
.add(StructField("start", child.dataType))
.add(StructField("end", child.dataType))
@@ -1941,6 +1941,22 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val STATEFUL_OPERATOR_ALLOW_MULTIPLE =
buildConf("spark.sql.streaming.statefulOperator.allowMultiple")
.internal()
.doc("When true, multiple stateful operators are allowed to be present in a streaming " +
"pipeline. The support for multiple stateful operators introduces a minor (semantically " +
"correct) change in respect to late record filtering - late records are detected and " +
"filtered in respect to the watermark from the previous microbatch instead of the " +
"current one. This is a behavior change for Spark streaming pipelines and we allow " +
Reviewer (Contributor) commented:

This implies that we apply the same watermark for all stateful operators, right? I'd expect existing tests to fail since we introduce a behavioral change, but given that existing tests all pass, it looks like this is due to the no-data batch, which effectively lets the previous watermark catch up with the next one.

alex-balikov (Contributor, Author) replied:

Currently it is the same watermark passed to all operators. The issue is that if anyone has unit tests which check exactly which records are filtered, with carefully constructed batches and Trigger.Once, such tests can detect the change in behavior and fail.

Reviewer (Contributor) replied:

Yes, we will have a bunch of errors in test suites when we disable the no-data batch. Most test cases assume that the no-data batch always happens.

"users to revert to the previous behavior of late record filtering (late records are " +
"detected and filtered by comparing with the current microbatch watermark) by setting " +
"the flag value to false. In this mode, only a single stateful operator will be allowed " +
"in a streaming pipeline.")
.version("3.4.0")
.booleanConf
.createWithDefault(true)
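
For illustration, reverting to the old semantics is a one-line session config change (a sketch; note the conf is marked internal, so it is not a public knob):

```scala
// Late records are again filtered against the current microbatch's
// watermark, and only one stateful operator is allowed per query.
spark.conf.set("spark.sql.streaming.statefulOperator.allowMultiple", "false")
```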

val STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION =
buildConf("spark.sql.streaming.statefulOperator.useStrictDistribution")
.internal()
@@ -229,7 +229,7 @@ class QueryExecution(
// output mode does not matter since there is no `Sink`.
new IncrementalExecution(
sparkSession, logical, OutputMode.Append(), "<unknown>",
UUID.randomUUID, UUID.randomUUID, 0, OffsetSeqMetadata(0, 0))
UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0))
} else {
this
}
@@ -678,7 +678,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val execPlan = FlatMapGroupsWithStateExec(
func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, dataAttr, sda, outputAttr,
None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs = None,
eventTimeWatermark = None, planLater(initialState), hasInitialState, planLater(child)
eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None,
planLater(initialState), hasInitialState, planLater(child)
)
execPlan :: Nil
case _ =>
@@ -697,7 +698,8 @@
val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
val execPlan = python.FlatMapGroupsInPandasWithStateExec(
func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout,
batchTimestampMs = None, eventTimeWatermark = None, planLater(child)
batchTimestampMs = None, eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None, planLater(child)
)
execPlan :: Nil
case _ =>
@@ -369,7 +369,8 @@ object AggUtils {
groupingAttributes,
stateInfo = None,
outputMode = None,
eventTimeWatermark = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
stateFormatVersion = stateFormatVersion,
partialMerged2)

@@ -472,7 +473,8 @@

// shuffle & sort happens here: most of details are also handled in this physical plan
val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes,
sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None,
sessionExpression.toAttribute, stateInfo = None,
eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None,
stateFormatVersion, partialMerged1)

val mergedSessions = {
@@ -501,7 +503,8 @@
sessionExpression.toAttribute,
stateInfo = None,
outputMode = None,
eventTimeWatermark = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
stateFormatVersion, mergedSessions)

val finalAndCompleteAggregate: SparkPlan = {
@@ -48,7 +48,8 @@ import org.apache.spark.util.CompletionIterator
* @param outputMode the output mode of `functionExpr`
* @param timeoutConf used to timeout groups that have not received data in a while
* @param batchTimestampMs processing timestamp of the current batch.
* @param eventTimeWatermark event time watermark for the current batch
* @param eventTimeWatermarkForLateEvents event time watermark for filtering late events
* @param eventTimeWatermarkForEviction event time watermark for state eviction
* @param child logical plan of the underlying data
*/
case class FlatMapGroupsInPandasWithStateExec(
@@ -61,9 +62,9 @@ case class FlatMapGroupsInPandasWithStateExec(
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
eventTimeWatermark: Option[Long],
child: SparkPlan)
extends UnaryExecNode with PythonSQLMetrics with FlatMapGroupsWithStateExecBase {
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {

// TODO(SPARK-40444): Add the support of initial state.
override protected val initialStateDeserializer: Expression = null
@@ -132,7 +133,7 @@ case class FlatMapGroupsInPandasWithStateExec(
if (isTimeoutEnabled) {
val timeoutThreshold = timeoutConf match {
case ProcessingTimeTimeout => batchTimestampMs.get
case EventTimeTimeout => eventTimeWatermark.get
case EventTimeTimeout => eventTimeWatermarkForEviction.get
case _ =>
throw new IllegalStateException(
s"Cannot filter timed out keys for $timeoutConf")
@@ -176,7 +177,7 @@
val groupedState = GroupStateImpl.createForStreaming(
Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r },
batchTimestampMs.getOrElse(NO_TIMESTAMP),
eventTimeWatermark.getOrElse(NO_TIMESTAMP),
eventTimeWatermarkForEviction.getOrElse(NO_TIMESTAMP),
timeoutConf,
hasTimedOut = hasTimedOut,
watermarkPresent).asInstanceOf[GroupStateImpl[Row]]
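
To make the watermark choice above concrete, a small illustration with hypothetical millisecond values: event-time timeouts are judged against the eviction watermark (the current batch's), never the late-events watermark.

```scala
// Hypothetical values, for illustration only.
val eventTimeWatermarkForLateEvents = 36000000L // 10:00, previous batch
val eventTimeWatermarkForEviction   = 36300000L // 10:05, current batch
val groupTimeoutTimestampMs         = 36180000L // 10:03

// The group times out in this batch: the eviction watermark passed it.
val hasTimedOut = groupTimeoutTimestampMs < eventTimeWatermarkForEviction
```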
@@ -54,8 +54,8 @@ trait FlatMapGroupsWithStateExecBase
protected val outputMode: OutputMode
protected val timeoutConf: GroupStateTimeout
protected val batchTimestampMs: Option[Long]
val eventTimeWatermark: Option[Long]

val eventTimeWatermarkForLateEvents: Option[Long]
val eventTimeWatermarkForEviction: Option[Long]
protected val isTimeoutEnabled: Boolean = timeoutConf != NoTimeout
protected val watermarkPresent: Boolean = child.output.exists {
case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
@@ -96,7 +96,8 @@ trait FlatMapGroupsWithStateExecBase
true // Always run batches to process timeouts
case EventTimeTimeout =>
// Process another non-data batch only if the watermark has changed in this executed plan
eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get
eventTimeWatermarkForEviction.isDefined &&
newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get
case _ =>
false
}
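
In other words, an extra no-data batch is scheduled only when the watermark strictly advanced past the eviction watermark the executed plan used; a tiny illustration with made-up values:

```scala
// The watermark did not move, so no further no-data batch runs; this is
// what keeps no-data batches from repeating forever.
val eventTimeWatermarkForEviction: Option[Long] = Some(36300000L)
val newBatchWatermarkMs = 36300000L
val runAnother = eventTimeWatermarkForEviction.exists(newBatchWatermarkMs > _)
// runAnother == false
```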
@@ -125,7 +126,7 @@ trait FlatMapGroupsWithStateExecBase
var timeoutProcessingStartTimeNs = currentTimeNs

// If timeout is based on event time, then filter late data based on watermark
val filteredIter = watermarkPredicateForData match {
val filteredIter = watermarkPredicateForDataForLateEvents match {
case Some(predicate) if timeoutConf == EventTimeTimeout =>
applyRemovingRowsOlderThanWatermark(iter, predicate)
case _ =>
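
For reference, a plausible sketch of the helper invoked above (the real implementation lives in the WatermarkSupport trait and also increments a rows-dropped-by-watermark metric): rows matching the late-event predicate are dropped before the user function ever sees them.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BasePredicate

// Sketch only: filter out rows the predicate marks as late, i.e. older
// than the late-events watermark.
def applyRemovingRowsOlderThanWatermark(
    iter: Iterator[InternalRow],
    predicateDropRowByWatermark: BasePredicate): Iterator[InternalRow] =
  iter.filterNot(row => predicateDropRowByWatermark.eval(row))
```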
@@ -189,8 +190,12 @@
case ProcessingTimeTimeout =>
require(batchTimestampMs.nonEmpty)
case EventTimeTimeout =>
require(eventTimeWatermark.nonEmpty) // watermark value has been populated
require(watermarkExpression.nonEmpty) // input schema has watermark attribute
// watermark value has been populated
require(eventTimeWatermarkForLateEvents.nonEmpty)
require(eventTimeWatermarkForEviction.nonEmpty)
// input schema has watermark attribute
require(watermarkExpressionForLateEvents.nonEmpty)
require(watermarkExpressionForEviction.nonEmpty)
case _ =>
}

@@ -310,7 +315,7 @@ trait FlatMapGroupsWithStateExecBase
if (isTimeoutEnabled) {
val timeoutThreshold = timeoutConf match {
case ProcessingTimeTimeout => batchTimestampMs.get
case EventTimeTimeout => eventTimeWatermark.get
case EventTimeTimeout => eventTimeWatermarkForEviction.get
case _ =>
throw new IllegalStateException(
s"Cannot filter timed out keys for $timeoutConf")
@@ -354,7 +359,8 @@
* @param outputMode the output mode of `func`
* @param timeoutConf used to timeout groups that have not received data in a while
* @param batchTimestampMs processing timestamp of the current batch.
* @param eventTimeWatermark event time watermark for the current batch
* @param eventTimeWatermarkForLateEvents event time watermark for filtering late events
* @param eventTimeWatermarkForEviction event time watermark for state eviction
* @param initialState the user specified initial state
* @param hasInitialState indicates whether the initial state is provided or not
* @param child the physical plan for the underlying data
@@ -375,7 +381,8 @@ case class FlatMapGroupsWithStateExec(
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
batchTimestampMs: Option[Long],
eventTimeWatermark: Option[Long],
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
initialState: SparkPlan,
hasInitialState: Boolean,
child: SparkPlan)
@@ -410,7 +417,7 @@ case class FlatMapGroupsWithStateExec(
val groupState = GroupStateImpl.createForStreaming(
Option(stateData.stateObj),
batchTimestampMs.getOrElse(NO_TIMESTAMP),
eventTimeWatermark.getOrElse(NO_TIMESTAMP),
eventTimeWatermarkForEviction.getOrElse(NO_TIMESTAMP),
timeoutConf,
hasTimedOut,
watermarkPresent)
@@ -48,6 +48,7 @@ class IncrementalExecution(
val queryId: UUID,
val runId: UUID,
val currentBatchId: Long,
val prevOffsetSeqMetadata: Option[OffsetSeqMetadata],
val offsetSeqMetadata: OffsetSeqMetadata)
extends QueryExecution(sparkSession, logicalPlan) with Logging {

@@ -112,6 +113,17 @@ class IncrementalExecution(
numStateStores)
}

// Watermarks to use for late record filtering and state eviction in stateful operators.
// Using the previous watermark for late record filtering is a Spark behavior change so we allow
// this to be disabled.
val eventTimeWatermarkForEviction = offsetSeqMetadata.batchWatermarkMs
val eventTimeWatermarkForLateEvents =
if (sparkSession.conf.get(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)) {
prevOffsetSeqMetadata.getOrElse(offsetSeqMetadata).batchWatermarkMs
} else {
eventTimeWatermarkForEviction
}
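
A worked example of the selection above, with hypothetical values: suppose the watermark advanced from 10:00 (previous microbatch) to 10:05 (current). With the flag on, a record at 10:02 survives late-record filtering even though state eviction still uses 10:05.

```scala
// Hypothetical millisecond values standing in for batchWatermarkMs.
val prevBatchWatermarkMs: Option[Long] = Some(36000000L) // 10:00
val currBatchWatermarkMs = 36300000L                     // 10:05
val allowMultiple = true                                 // the new conf

val watermarkForEviction = currBatchWatermarkMs
val watermarkForLateEvents =
  if (allowMultiple) prevBatchWatermarkMs.getOrElse(currBatchWatermarkMs)
  else watermarkForEviction

// A record with event time 10:02 (36120000L) is not filtered as late,
// since 36120000 >= watermarkForLateEvents (36000000).
```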

/** Locates save/restore pairs surrounding aggregation. */
val state = new Rule[SparkPlan] {

@@ -158,15 +170,16 @@ class IncrementalExecution(
case a: UpdatingSessionsExec if a.isStreaming =>
a.copy(numShufflePartitions = Some(numStateStores))

case StateStoreSaveExec(keys, None, None, None, stateFormatVersion,
case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion,
UnaryExecNode(agg,
StateStoreRestoreExec(_, None, _, child))) =>
val aggStateInfo = nextStatefulOperationStateInfo
StateStoreSaveExec(
keys,
Some(aggStateInfo),
Some(outputMode),
Some(offsetSeqMetadata.batchWatermarkMs),
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction),
stateFormatVersion,
agg.withNewChildren(
StateStoreRestoreExec(
@@ -175,58 +188,65 @@
stateFormatVersion,
child) :: Nil))

case SessionWindowStateStoreSaveExec(keys, session, None, None, None, stateFormatVersion,
case SessionWindowStateStoreSaveExec(keys, session, None, None, None, None,
stateFormatVersion,
UnaryExecNode(agg,
SessionWindowStateStoreRestoreExec(_, _, None, None, _, child))) =>
SessionWindowStateStoreRestoreExec(_, _, None, None, None, _, child))) =>
val aggStateInfo = nextStatefulOperationStateInfo
SessionWindowStateStoreSaveExec(
keys,
session,
Some(aggStateInfo),
Some(outputMode),
Some(offsetSeqMetadata.batchWatermarkMs),
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction),
stateFormatVersion,
agg.withNewChildren(
SessionWindowStateStoreRestoreExec(
keys,
session,
Some(aggStateInfo),
Some(offsetSeqMetadata.batchWatermarkMs),
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction),
stateFormatVersion,
child) :: Nil))

case StreamingDeduplicateExec(keys, child, None, None) =>
case StreamingDeduplicateExec(keys, child, None, None, None) =>
StreamingDeduplicateExec(
keys,
child,
Some(nextStatefulOperationStateInfo),
Some(offsetSeqMetadata.batchWatermarkMs))
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction))

case m: FlatMapGroupsWithStateExec =>
// We set this to true only for the first batch of the streaming query.
val hasInitialState = (currentBatchId == 0L && m.hasInitialState)
m.copy(
stateInfo = Some(nextStatefulOperationStateInfo),
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs),
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction),
hasInitialState = hasInitialState
)

case m: FlatMapGroupsInPandasWithStateExec =>
m.copy(
stateInfo = Some(nextStatefulOperationStateInfo),
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction)
)

case j: StreamingSymmetricHashJoinExec =>
j.copy(
stateInfo = Some(nextStatefulOperationStateInfo),
eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs),
eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents),
eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction),
stateWatermarkPredicates =
StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates(
j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full,
Some(offsetSeqMetadata.batchWatermarkMs)))
Some(eventTimeWatermarkForEviction)))

case l: StreamingGlobalLimitExec =>
l.copy(
@@ -695,6 +695,7 @@ class MicroBatchExecution(
id,
runId,
currentBatchId,
offsetLog.offsetSeqMetadataForBatchId(currentBatchId - 1),
offsetSeqMetadata)
lastExecution.executedPlan // Force the lazy generation of execution plan
}