diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala index 77e8dfde87bbb..02273b0c46134 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala @@ -68,11 +68,29 @@ case class SessionWindow(timeColumn: Expression, gapDuration: Expression) extend with Unevaluable with NonSQLExpression { + private def inputTypeOnTimeColumn: AbstractDataType = { + TypeCollection( + AnyTimestampType, + // Below two types cover both time window & session window, since they produce the same type + // of output as window column. + new StructType() + .add(StructField("start", TimestampType)) + .add(StructField("end", TimestampType)), + new StructType() + .add(StructField("start", TimestampNTZType)) + .add(StructField("end", TimestampNTZType)) + ) + } + + // NOTE: if the window column is given as a time column, we resolve it to the point of time, + // which resolves to either TimestampType or TimestampNTZType. That means, timeColumn may not + // be "resolved", so it is safe to not rely on the data type of timeColumn directly. + override def children: Seq[Expression] = Seq(timeColumn, gapDuration) - override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType, AnyDataType) + override def inputTypes: Seq[AbstractDataType] = Seq(inputTypeOnTimeColumn, AnyDataType) override def dataType: DataType = new StructType() - .add(StructField("start", timeColumn.dataType)) - .add(StructField("end", timeColumn.dataType)) + .add(StructField("start", children.head.dataType)) + .add(StructField("end", children.head.dataType)) // This expression is replaced in the analyzer. override lazy val resolved = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 93c1074dfbede..bc9b7de7464e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -96,8 +96,26 @@ case class TimeWindow( this(timeColumn, windowDuration, windowDuration) } + private def inputTypeOnTimeColumn: AbstractDataType = { + TypeCollection( + AnyTimestampType, + // Below two types cover both time window & session window, since they produce the same type + // of output as window column. + new StructType() + .add(StructField("start", TimestampType)) + .add(StructField("end", TimestampType)), + new StructType() + .add(StructField("start", TimestampNTZType)) + .add(StructField("end", TimestampNTZType)) + ) + } + + // NOTE: if the window column is given as a time column, we resolve it to the point of time, + // which resolves to either TimestampType or TimestampNTZType. That means, timeColumn may not + // be "resolved", so it is safe to not rely on the data type of timeColumn directly. 
+ override def child: Expression = timeColumn - override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimestampType) + override def inputTypes: Seq[AbstractDataType] = Seq(inputTypeOnTimeColumn) override def dataType: DataType = new StructType() .add(StructField("start", child.dataType)) .add(StructField("end", child.dataType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0a60c6b0265af..3854f3190a764 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1941,6 +1941,22 @@ object SQLConf { .booleanConf .createWithDefault(true) + val STATEFUL_OPERATOR_ALLOW_MULTIPLE = + buildConf("spark.sql.streaming.statefulOperator.allowMultiple") + .internal() + .doc("When true, multiple stateful operators are allowed to be present in a streaming " + + "pipeline. The support for multiple stateful operators introduces a minor (semantically " + + "correct) change in respect to late record filtering - late records are detected and " + + "filtered in respect to the watermark from the previous microbatch instead of the " + + "current one. This is a behavior change for Spark streaming pipelines and we allow " + + "users to revert to the previous behavior of late record filtering (late records are " + + "detected and filtered by comparing with the current microbatch watermark) by setting " + + "the flag value to false. In this mode, only a single stateful operator will be allowed " + + "in a streaming pipeline.") + .version("3.4.0") + .booleanConf + .createWithDefault(true) + val STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION = buildConf("spark.sql.streaming.statefulOperator.useStrictDistribution") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 10b763b1b5134..8bf5d3d317bdc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -229,7 +229,7 @@ class QueryExecution( // output mode does not matter since there is no `Sink`. 
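Stepping back for a moment: the expression-level changes above (SessionWindow/TimeWindow accepting a window struct with `start`/`end` as the time column) are what let the window column produced by one streaming aggregation be fed straight into another. A minimal usage sketch, mirroring the first test in the new MultiStatefulOperatorsSuite added later in this patch (`spark`, `events` and `eventTime` are illustrative names, not part of the patch):

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

// Two chained tumbling-window aggregations on one streaming DataFrame.
val chained = events
  .withWatermark("eventTime", "0 seconds")
  .groupBy(window($"eventTime", "5 seconds").as("window"))
  .agg(count("*").as("count"))
  // The "window" struct is now a valid time column for window().
  .groupBy(window($"window", "10 seconds"))
  .agg(sum("count").as("sum"))
```

Note that the new suite still disables spark.sql.streaming.unsupportedOperationCheck for such chains; relaxing the checker is tracked separately as SPARK-40940.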
new IncrementalExecution( sparkSession, logical, OutputMode.Append(), "", - UUID.randomUUID, UUID.randomUUID, 0, OffsetSeqMetadata(0, 0)) + UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0)) } else { this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 03e722a86fb21..b96e47846fc93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -678,7 +678,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val execPlan = FlatMapGroupsWithStateExec( func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, dataAttr, sda, outputAttr, None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs = None, - eventTimeWatermark = None, planLater(initialState), hasInitialState, planLater(child) + eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, + planLater(initialState), hasInitialState, planLater(child) ) execPlan :: Nil case _ => @@ -697,7 +698,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) val execPlan = python.FlatMapGroupsInPandasWithStateExec( func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, - batchTimestampMs = None, eventTimeWatermark = None, planLater(child) + batchTimestampMs = None, eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None, planLater(child) ) execPlan :: Nil case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 579a00c7996f2..557f0e897ee40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -369,7 +369,8 @@ object AggUtils { groupingAttributes, stateInfo = None, outputMode = None, - eventTimeWatermark = None, + eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None, stateFormatVersion = stateFormatVersion, partialMerged2) @@ -472,7 +473,8 @@ object AggUtils { // shuffle & sort happens here: most of details are also handled in this physical plan val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes, - sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, + sessionExpression.toAttribute, stateInfo = None, + eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, stateFormatVersion, partialMerged1) val mergedSessions = { @@ -501,7 +503,8 @@ object AggUtils { sessionExpression.toAttribute, stateInfo = None, outputMode = None, - eventTimeWatermark = None, + eventTimeWatermarkForLateEvents = None, + eventTimeWatermarkForEviction = None, stateFormatVersion, mergedSessions) val finalAndCompleteAggregate: SparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 3b096f07241fc..bc1a5ae17e4d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -48,7 +48,8 @@ import org.apache.spark.util.CompletionIterator * @param outputMode the output mode of `functionExpr` * @param timeoutConf used to timeout groups that have not received data in a while * @param batchTimestampMs processing timestamp of the current batch. - * @param eventTimeWatermark event time watermark for the current batch + * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events + * @param eventTimeWatermarkForEviction event time watermark for state eviction * @param child logical plan of the underlying data */ case class FlatMapGroupsInPandasWithStateExec( @@ -61,9 +62,9 @@ case class FlatMapGroupsInPandasWithStateExec( outputMode: OutputMode, timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], - eventTimeWatermark: Option[Long], - child: SparkPlan) - extends UnaryExecNode with PythonSQLMetrics with FlatMapGroupsWithStateExecBase { + eventTimeWatermarkForLateEvents: Option[Long], + eventTimeWatermarkForEviction: Option[Long], + child: SparkPlan) extends UnaryExecNode with PythonSQLMetrics with FlatMapGroupsWithStateExecBase { // TODO(SPARK-40444): Add the support of initial state. override protected val initialStateDeserializer: Expression = null @@ -132,7 +133,7 @@ case class FlatMapGroupsInPandasWithStateExec( if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get - case EventTimeTimeout => eventTimeWatermark.get + case EventTimeTimeout => eventTimeWatermarkForEviction.get case _ => throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") @@ -176,7 +177,7 @@ case class FlatMapGroupsInPandasWithStateExec( val groupedState = GroupStateImpl.createForStreaming( Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r }, batchTimestampMs.getOrElse(NO_TIMESTAMP), - eventTimeWatermark.getOrElse(NO_TIMESTAMP), + eventTimeWatermarkForEviction.getOrElse(NO_TIMESTAMP), timeoutConf, hasTimedOut = hasTimedOut, watermarkPresent).asInstanceOf[GroupStateImpl[Row]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 790a652f21124..138029e76c118 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -54,8 +54,8 @@ trait FlatMapGroupsWithStateExecBase protected val outputMode: OutputMode protected val timeoutConf: GroupStateTimeout protected val batchTimestampMs: Option[Long] - val eventTimeWatermark: Option[Long] - + val eventTimeWatermarkForLateEvents: Option[Long] + val eventTimeWatermarkForEviction: Option[Long] protected val isTimeoutEnabled: Boolean = timeoutConf != NoTimeout protected val watermarkPresent: Boolean = child.output.exists { case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true @@ -96,7 +96,8 @@ trait FlatMapGroupsWithStateExecBase true // Always run batches to process timeouts case EventTimeTimeout => // Process another non-data batch only if the watermark has changed in this executed plan - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermarkForEviction.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get case _ =>
false } @@ -125,7 +126,7 @@ trait FlatMapGroupsWithStateExecBase var timeoutProcessingStartTimeNs = currentTimeNs // If timeout is based on event time, then filter late data based on watermark - val filteredIter = watermarkPredicateForData match { + val filteredIter = watermarkPredicateForDataForLateEvents match { case Some(predicate) if timeoutConf == EventTimeTimeout => applyRemovingRowsOlderThanWatermark(iter, predicate) case _ => @@ -189,8 +190,12 @@ trait FlatMapGroupsWithStateExecBase case ProcessingTimeTimeout => require(batchTimestampMs.nonEmpty) case EventTimeTimeout => - require(eventTimeWatermark.nonEmpty) // watermark value has been populated - require(watermarkExpression.nonEmpty) // input schema has watermark attribute + // watermark value has been populated + require(eventTimeWatermarkForLateEvents.nonEmpty) + require(eventTimeWatermarkForEviction.nonEmpty) + // input schema has watermark attribute + require(watermarkExpressionForLateEvents.nonEmpty) + require(watermarkExpressionForEviction.nonEmpty) case _ => } @@ -310,7 +315,7 @@ trait FlatMapGroupsWithStateExecBase if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get - case EventTimeTimeout => eventTimeWatermark.get + case EventTimeTimeout => eventTimeWatermarkForEviction.get case _ => throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") @@ -354,7 +359,8 @@ trait FlatMapGroupsWithStateExecBase * @param outputMode the output mode of `func` * @param timeoutConf used to timeout groups that have not received data in a while * @param batchTimestampMs processing timestamp of the current batch. - * @param eventTimeWatermark event time watermark for the current batch + * @param eventTimeWatermarkForLateEvents event time watermark for filtering late events + * @param eventTimeWatermarkForEviction event time watermark for state eviction * @param initialState the user specified initial state * @param hasInitialState indicates whether the initial state is provided or not * @param child the physical plan for the underlying data @@ -375,7 +381,8 @@ case class FlatMapGroupsWithStateExec( outputMode: OutputMode, timeoutConf: GroupStateTimeout, batchTimestampMs: Option[Long], - eventTimeWatermark: Option[Long], + eventTimeWatermarkForLateEvents: Option[Long], + eventTimeWatermarkForEviction: Option[Long], initialState: SparkPlan, hasInitialState: Boolean, child: SparkPlan) @@ -410,7 +417,7 @@ case class FlatMapGroupsWithStateExec( val groupState = GroupStateImpl.createForStreaming( Option(stateData.stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), - eventTimeWatermark.getOrElse(NO_TIMESTAMP), + eventTimeWatermarkForEviction.getOrElse(NO_TIMESTAMP), timeoutConf, hasTimedOut, watermarkPresent) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index f386282a0b3e6..574709d05b0da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -48,6 +48,7 @@ class IncrementalExecution( val queryId: UUID, val runId: UUID, val currentBatchId: Long, + val prevOffsetSeqMetadata: Option[OffsetSeqMetadata], val offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { @@ -112,6 +113,17 @@ class IncrementalExecution( 
numStateStores) } + // Watermarks to use for late record filtering and state eviction in stateful operators. + // Using the previous watermark for late record filtering is a Spark behavior change so we allow + // this to be disabled. + val eventTimeWatermarkForEviction = offsetSeqMetadata.batchWatermarkMs + val eventTimeWatermarkForLateEvents = + if (sparkSession.conf.get(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)) { + prevOffsetSeqMetadata.getOrElse(offsetSeqMetadata).batchWatermarkMs + } else { + eventTimeWatermarkForEviction + } + /** Locates save/restore pairs surrounding aggregation. */ val state = new Rule[SparkPlan] { @@ -158,7 +170,7 @@ class IncrementalExecution( case a: UpdatingSessionsExec if a.isStreaming => a.copy(numShufflePartitions = Some(numStateStores)) - case StateStoreSaveExec(keys, None, None, None, stateFormatVersion, + case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion, UnaryExecNode(agg, StateStoreRestoreExec(_, None, _, child))) => val aggStateInfo = nextStatefulOperationStateInfo @@ -166,7 +178,8 @@ class IncrementalExecution( keys, Some(aggStateInfo), Some(outputMode), - Some(offsetSeqMetadata.batchWatermarkMs), + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), stateFormatVersion, agg.withNewChildren( StateStoreRestoreExec( @@ -175,32 +188,36 @@ class IncrementalExecution( stateFormatVersion, child) :: Nil)) - case SessionWindowStateStoreSaveExec(keys, session, None, None, None, stateFormatVersion, + case SessionWindowStateStoreSaveExec(keys, session, None, None, None, None, + stateFormatVersion, UnaryExecNode(agg, - SessionWindowStateStoreRestoreExec(_, _, None, None, _, child))) => + SessionWindowStateStoreRestoreExec(_, _, None, None, None, _, child))) => val aggStateInfo = nextStatefulOperationStateInfo SessionWindowStateStoreSaveExec( keys, session, Some(aggStateInfo), Some(outputMode), - Some(offsetSeqMetadata.batchWatermarkMs), + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), stateFormatVersion, agg.withNewChildren( SessionWindowStateStoreRestoreExec( keys, session, Some(aggStateInfo), - Some(offsetSeqMetadata.batchWatermarkMs), + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), stateFormatVersion, child) :: Nil)) - case StreamingDeduplicateExec(keys, child, None, None) => + case StreamingDeduplicateExec(keys, child, None, None, None) => StreamingDeduplicateExec( keys, child, Some(nextStatefulOperationStateInfo), - Some(offsetSeqMetadata.batchWatermarkMs)) + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction)) case m: FlatMapGroupsWithStateExec => // We set this to true only for the first batch of the streaming query. 
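To make the two watermark values computed above concrete: with multiple stateful operators allowed (the default), late input is filtered against the watermark of the previous microbatch while state eviction still uses the current one. A worked example with hypothetical values, matching the "op1 W (21, 29)" batch annotations in the test suite added at the end of this patch:

```scala
// Illustrative values only (not part of the patch):
// prevOffsetSeqMetadata.get.batchWatermarkMs == 21000L  // watermark after the previous batch
// offsetSeqMetadata.batchWatermarkMs         == 29000L  // watermark for the current batch
//
// With spark.sql.streaming.statefulOperator.allowMultiple = true (default):
//   eventTimeWatermarkForLateEvents == 21000L  // drop late input older than the previous watermark
//   eventTimeWatermarkForEviction   == 29000L  // evict state / finalize windows against the current one
//
// With the flag set to false, both values are 29000L, i.e. the previous behavior.
```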
@@ -208,7 +225,8 @@ class IncrementalExecution( m.copy( stateInfo = Some(nextStatefulOperationStateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), - eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs), + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), hasInitialState = hasInitialState ) @@ -216,17 +234,19 @@ m.copy( stateInfo = Some(nextStatefulOperationStateInfo), batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), - eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs) + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction) ) case j: StreamingSymmetricHashJoinExec => j.copy( stateInfo = Some(nextStatefulOperationStateInfo), - eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs), + eventTimeWatermarkForLateEvents = Some(eventTimeWatermarkForLateEvents), + eventTimeWatermarkForEviction = Some(eventTimeWatermarkForEviction), stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, - Some(offsetSeqMetadata.batchWatermarkMs))) + Some(eventTimeWatermarkForEviction))) case l: StreamingGlobalLimitExec => l.copy( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 5f8fb93827b32..7ed19b3511476 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -695,6 +695,7 @@ class MicroBatchExecution( id, runId, currentBatchId, + offsetLog.offsetSeqMetadataForBatchId(currentBatchId - 1), offsetSeqMetadata) lastExecution.executedPlan // Force the lazy generation of execution plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index 82e50263893db..7f00717ea4df6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -102,6 +102,10 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) } } } + + def offsetSeqMetadataForBatchId(batchId: Long): Option[OffsetSeqMetadata] = { + if (batchId < 0) None else get(batchId).flatMap(_.metadata) + } } object OffsetSeqLog { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 4a8f3b18c098a..dfde4156812b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -119,7 +119,8 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} * @param condition Conditions to filter rows, split by left, right, and joined.
See * [[JoinConditionSplitPredicates]] * @param stateInfo Version information required to read join state (buffered rows) - * @param eventTimeWatermark Watermark of input event, same for both sides + * @param eventTimeWatermarkForLateEvents Watermark for filtering late events, same for both sides + * @param eventTimeWatermarkForEviction Watermark for state eviction * @param stateWatermarkPredicates Predicates for removal of state, see * [[JoinStateWatermarkPredicates]] * @param left Left child plan @@ -131,7 +132,8 @@ case class StreamingSymmetricHashJoinExec( joinType: JoinType, condition: JoinConditionSplitPredicates, stateInfo: Option[StatefulOperatorStateInfo], - eventTimeWatermark: Option[Long], + eventTimeWatermarkForLateEvents: Option[Long], + eventTimeWatermarkForEviction: Option[Long], stateWatermarkPredicates: JoinStateWatermarkPredicates, stateFormatVersion: Int, left: SparkPlan, @@ -148,7 +150,8 @@ case class StreamingSymmetricHashJoinExec( this( leftKeys, rightKeys, joinType, JoinConditionSplitPredicates(condition, left, right), - stateInfo = None, eventTimeWatermark = None, + stateInfo = None, + eventTimeWatermarkForLateEvents = None, eventTimeWatermarkForEviction = None, stateWatermarkPredicates = JoinStateWatermarkPredicates(), stateFormatVersion, left, right) } @@ -222,7 +225,8 @@ case class StreamingSymmetricHashJoinExec( // Latest watermark value is more than that used in this previous executed plan val watermarkHasChanged = - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermarkForEviction.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get watermarkUsedForStateCleanup && watermarkHasChanged } @@ -555,7 +559,8 @@ case class StreamingSymmetricHashJoinExec( val watermarkAttribute = inputAttributes.find(_.metadata.contains(delayKey)) val nonLateRows = - WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match { + WatermarkSupport.watermarkExpression( + watermarkAttribute, eventTimeWatermarkForLateEvents) match { case Some(watermarkExpr) => val predicate = Predicate.create(watermarkExpr, inputAttributes) applyRemovingRowsOlderThanWatermark(inputIter, predicate) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index 2f62dbd7ec578..7bf6381e08ffe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -137,7 +137,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { leftKeys: Seq[Expression], rightKeys: Seq[Expression], condition: Option[Expression], - eventTimeWatermark: Option[Long]): JoinStateWatermarkPredicates = { + eventTimeWatermarkForEviction: Option[Long]): JoinStateWatermarkPredicates = { // Join keys of both sides generate rows of the same fields, that is, same sequence of data @@ -172,7 +172,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { joinKeyOrdinalForWatermark.get, oneSideJoinKeys(joinKeyOrdinalForWatermark.get).dataType, oneSideJoinKeys(joinKeyOrdinalForWatermark.get).nullable) - val expr = watermarkExpression(Some(keyExprWithWatermark), eventTimeWatermark) + val expr = watermarkExpression(Some(keyExprWithWatermark), eventTimeWatermarkForEviction) 
expr.map(JoinStateKeyWatermarkPredicate.apply _) } else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs @@ -180,7 +180,7 @@ object StreamingSymmetricHashJoinHelper extends Logging { attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes), attributesWithEventWatermark = AttributeSet(otherSideInputAttributes), condition, - eventTimeWatermark) + eventTimeWatermarkForEviction) val inputAttributeWithWatermark = oneSideInputAttributes.find(_.metadata.contains(delayKey)) val expr = watermarkExpression(inputAttributeWithWatermark, stateValueWatermark) expr.map(JoinStateValueWatermarkPredicate.apply _) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 5b620eec25feb..e8092e072bc22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -218,6 +218,7 @@ class ContinuousExecution( id, runId, currentBatchId, + None, offsetSeqMetadata) lastExecution.executedPlan // Force the lazy generation of execution plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala index 0a603a3b14139..3f474ea533ca1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactor */ class MicroBatchWrite(epochId: Long, val writeSupport: StreamingWrite) extends BatchWrite { override def toString: String = { - s"MicroBathWrite[epoch: $epochId, writer: $writeSupport]" + s"MicroBatchWrite[epoch: $epochId, writer: $writeSupport]" } override def commit(messages: Array[WriterCommitMessage]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index b540f9f00939a..457e5f80ae6bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -212,34 +212,73 @@ trait WatermarkSupport extends SparkPlan { /** The keys that may have a watermark attribute. */ def keyExpressions: Seq[Attribute] - /** The watermark value. */ - def eventTimeWatermark: Option[Long] + /** + * The watermark value for filtering late events/records. This should be the previous + * batch state eviction watermark. + */ + def eventTimeWatermarkForLateEvents: Option[Long] + /** + * The watermark value for closing aggregates and evicting state. + * It is different from the late events filtering watermark (consider chained aggregators + * agg1 -> agg2: agg1 evicts state which will be effectively late against the eviction watermark + * but should not be late for agg2 input late record filtering watermark. Thus agg1 and agg2 use + * the current batch watermark for state eviction but the previous batch watermark for late + * record filtering. 
+ */ + def eventTimeWatermarkForEviction: Option[Long] + + /** Generate an expression that matches data older than late event filtering watermark */ + lazy val watermarkExpressionForLateEvents: Option[Expression] = + watermarkExpression(eventTimeWatermarkForLateEvents) + /** Generate an expression that matches data older than the state eviction watermark */ + lazy val watermarkExpressionForEviction: Option[Expression] = + watermarkExpression(eventTimeWatermarkForEviction) /** Generate an expression that matches data older than the watermark */ - lazy val watermarkExpression: Option[Expression] = { + private def watermarkExpression(watermark: Option[Long]): Option[Expression] = { WatermarkSupport.watermarkExpression( - child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)), - eventTimeWatermark) + child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)), watermark) } - /** Predicate based on keys that matches data older than the watermark */ - lazy val watermarkPredicateForKeys: Option[BasePredicate] = watermarkExpression.flatMap { e => - if (keyExpressions.exists(_.metadata.contains(EventTimeWatermark.delayKey))) { - Some(Predicate.create(e, keyExpressions)) - } else { - None + /** Predicate based on keys that matches data older than the late event filtering watermark */ + lazy val watermarkPredicateForKeysForLateEvents: Option[BasePredicate] = + watermarkPredicateForKeys(watermarkExpressionForLateEvents) + + /** Generate an expression that matches data older than the state eviction watermark */ + lazy val watermarkPredicateForKeysForEviction: Option[BasePredicate] = + watermarkPredicateForKeys(watermarkExpressionForEviction) + + private def watermarkPredicateForKeys( + watermarkExpression: Option[Expression]): Option[BasePredicate] = { + watermarkExpression.flatMap { e => + if (keyExpressions.exists(_.metadata.contains(EventTimeWatermark.delayKey))) { + Some(Predicate.create(e, keyExpressions)) + } else { + None + } } } - /** Predicate based on the child output that matches data older than the watermark. */ - lazy val watermarkPredicateForData: Option[BasePredicate] = + /** + * Predicate based on the child output that matches data older than the watermark for late events + * filtering. 
+ */ + lazy val watermarkPredicateForDataForLateEvents: Option[BasePredicate] = + watermarkPredicateForData(watermarkExpressionForLateEvents) + + lazy val watermarkPredicateForDataForEviction: Option[BasePredicate] = + watermarkPredicateForData(watermarkExpressionForEviction) + + private def watermarkPredicateForData( + watermarkExpression: Option[Expression]): Option[BasePredicate] = { watermarkExpression.map(Predicate.create(_, child.output)) + } protected def removeKeysOlderThanWatermark(store: StateStore): Unit = { - if (watermarkPredicateForKeys.nonEmpty) { + if (watermarkPredicateForKeysForEviction.nonEmpty) { val numRemovedStateRows = longMetric("numRemovedStateRows") store.iterator().foreach { rowPair => - if (watermarkPredicateForKeys.get.eval(rowPair.key)) { + if (watermarkPredicateForKeysForEviction.get.eval(rowPair.key)) { store.remove(rowPair.key) numRemovedStateRows += 1 } @@ -250,10 +289,10 @@ trait WatermarkSupport extends SparkPlan { protected def removeKeysOlderThanWatermark( storeManager: StreamingAggregationStateManager, store: StateStore): Unit = { - if (watermarkPredicateForKeys.nonEmpty) { + if (watermarkPredicateForKeysForEviction.nonEmpty) { val numRemovedStateRows = longMetric("numRemovedStateRows") storeManager.keys(store).foreach { keyRow => - if (watermarkPredicateForKeys.get.eval(keyRow)) { + if (watermarkPredicateForKeysForEviction.get.eval(keyRow)) { storeManager.remove(store, keyRow) numRemovedStateRows += 1 } @@ -354,7 +393,8 @@ case class StateStoreSaveExec( keyExpressions: Seq[Attribute], stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, - eventTimeWatermark: Option[Long] = None, + eventTimeWatermarkForLateEvents: Option[Long] = None, + eventTimeWatermarkForEviction: Option[Long] = None, stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { @@ -407,7 +447,7 @@ case class StateStoreSaveExec( case Some(Append) => allUpdatesTimeMs += timeTakenMs { val filteredIter = applyRemovingRowsOlderThanWatermark(iter, - watermarkPredicateForData.get) + watermarkPredicateForDataForLateEvents.get) while (filteredIter.hasNext) { val row = filteredIter.next().asInstanceOf[UnsafeRow] stateManager.put(store, row) @@ -423,7 +463,7 @@ case class StateStoreSaveExec( var removedValueRow: InternalRow = null while(rangeIter.hasNext && removedValueRow == null) { val rowPair = rangeIter.next() - if (watermarkPredicateForKeys.get.eval(rowPair.key)) { + if (watermarkPredicateForKeysForEviction.get.eval(rowPair.key)) { stateManager.remove(store, rowPair.key) numRemovedStateRows += 1 removedValueRow = rowPair.value @@ -453,7 +493,7 @@ case class StateStoreSaveExec( new NextIterator[InternalRow] { // Filter late date using watermark if specified - private[this] val baseIterator = watermarkPredicateForData match { + private[this] val baseIterator = watermarkPredicateForDataForLateEvents match { case Some(predicate) => applyRemovingRowsOlderThanWatermark(iter, predicate) case None => iter } @@ -507,8 +547,8 @@ case class StateStoreSaveExec( override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { (outputMode.contains(Append) || outputMode.contains(Update)) && - eventTimeWatermark.isDefined && - newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermarkForEviction.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get } override protected def withNewChildInternal(newChild: SparkPlan): StateStoreSaveExec = @@ -525,7 +565,8 @@ case 
class SessionWindowStateStoreRestoreExec( keyWithoutSessionExpressions: Seq[Attribute], sessionExpression: Attribute, stateInfo: Option[StatefulOperatorStateInfo], - eventTimeWatermark: Option[Long], + eventTimeWatermarkForLateEvents: Option[Long], + eventTimeWatermarkForEviction: Option[Long], stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreReader with WatermarkSupport { @@ -555,7 +596,7 @@ case class SessionWindowStateStoreRestoreExec( Some(session.streams.stateStoreCoordinator)) { case (store, iter) => // We need to filter out outdated inputs - val filteredIterator = watermarkPredicateForData match { + val filteredIterator = watermarkPredicateForDataForLateEvents match { case Some(predicate) => iter.filter((row: InternalRow) => { val shouldKeep = !predicate.eval(row) if (!shouldKeep) longMetric("numRowsDroppedByWatermark") += 1 @@ -611,7 +652,8 @@ case class SessionWindowStateStoreSaveExec( sessionExpression: Attribute, stateInfo: Option[StatefulOperatorStateInfo] = None, outputMode: Option[OutputMode] = None, - eventTimeWatermark: Option[Long] = None, + eventTimeWatermarkForLateEvents: Option[Long] = None, + eventTimeWatermarkForEviction: Option[Long] = None, stateFormatVersion: Int, child: SparkPlan) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { @@ -667,7 +709,7 @@ case class SessionWindowStateStoreSaveExec( val removalStartTimeNs = System.nanoTime new NextIterator[InternalRow] { private val removedIter = stateManager.removeByValueCondition( - store, watermarkPredicateForData.get.eval) + store, watermarkPredicateForDataForEviction.get.eval) override protected def getNext(): InternalRow = { if (!removedIter.hasNext) { @@ -704,8 +746,8 @@ case class SessionWindowStateStoreSaveExec( override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { (outputMode.contains(Append) || outputMode.contains(Update)) && - eventTimeWatermark.isDefined && - newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermarkForEviction.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get } private def putToStore(iter: Iterator[InternalRow], store: StateStore): Unit = { @@ -775,7 +817,8 @@ case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], child: SparkPlan, stateInfo: Option[StatefulOperatorStateInfo] = None, - eventTimeWatermark: Option[Long] = None) + eventTimeWatermarkForLateEvents: Option[Long] = None, + eventTimeWatermarkForEviction: Option[Long] = None) extends UnaryExecNode with StateStoreWriter with WatermarkSupport { /** Distribute by grouping attributes */ @@ -807,7 +850,7 @@ case class StreamingDeduplicateExec( val commitTimeMs = longMetric("commitTimeMs") val numDroppedDuplicateRows = longMetric("numDroppedDuplicateRows") - val baseIterator = watermarkPredicateForData match { + val baseIterator = watermarkPredicateForDataForLateEvents match { case Some(predicate) => applyRemovingRowsOlderThanWatermark(iter, predicate) case None => iter } @@ -851,7 +894,8 @@ case class StreamingDeduplicateExec( override def shortName: String = "dedupe" override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { - eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + eventTimeWatermarkForEviction.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermarkForEviction.get } override protected def withNewChildInternal(newChild: SparkPlan): StreamingDeduplicateExec = diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 14f083bbd307a..49f4214ac1ae0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -1048,7 +1048,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { hasInitialState, sga, sda, se, i, c) => FlatMapGroupsWithStateExec( f, k, v, se, g, sga, d, sda, o, None, s, stateFormatVersion, m, t, - Some(currentBatchTimestamp), Some(currentBatchWatermark), + Some(currentBatchTimestamp), Some(0), Some(currentBatchWatermark), RDDScanExec(g, emptyRdd, "rdd"), hasInitialState, RDDScanExec(g, emptyRdd, "rdd")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala new file mode 100644 index 0000000000000..0a3ea40a677ad --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.functions._ + +// Tests for the multiple stateful operators support. +class MultiStatefulOperatorsSuite + extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { + + import testImplicits._ + + before { + SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' + spark.streams.stateStoreCoordinator // initialize the lazy coordinator + } + + after { + StateStore.stop() + } + + test("window agg -> window agg, append mode") { + // TODO: SPARK-40940 - Fix the unsupported ops checker to allow chaining of stateful ops. 
+ withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val inputData = MemoryStream[Int] + + val stream = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "0 seconds") + .groupBy(window($"eventTime", "5 seconds").as("window")) + .agg(count("*").as("count")) + .groupBy(window($"window", "10 seconds")) + .agg(count("*").as("count"), sum("count").as("sum")) + .select($"window".getField("start").cast("long").as[Long], + $"count".as[Long], $"sum".as[Long]) + + testStream(stream)( + AddData(inputData, 10 to 21: _*), + // op1 W (0, 0) + // agg: [10, 15) 5, [15, 20) 5, [20, 25) 2 + // output: None + // state: [10, 15) 5, [15, 20) 5, [20, 25) 2 + // op2 W (0, 0) + // agg: None + // output: None + // state: None + + // no-data batch triggered + + // op1 W (0, 21) + // agg: None + // output: [10, 15) 5, [15, 20) 5 + // state: [20, 25) 2 + // op2 W (0, 21) + // agg: [10, 20) (2, 10) + // output: [10, 20) (2, 10) + // state: None + CheckNewAnswer((10, 2, 10)), + assertNumStateRows(Seq(0, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0)), + + AddData(inputData, 10 to 29: _*), + // op1 W (21, 21) + // agg: [10, 15) 5 - late, [15, 20) 5 - late, [20, 25) 5, [25, 30) 5 + // output: None + // state: [20, 25) 7, [25, 30) 5 + // op2 W (21, 21) + // agg: None + // output: None + // state: None + + // no-data batch triggered + + // op1 W (21, 29) + // agg: None + // output: [20, 25) 7 + // state: [25, 30) 5 + // op2 W (21, 29) + // agg: [20, 30) (1, 7) + // output: None + // state: [20, 30) (1, 7) + CheckNewAnswer(), + assertNumStateRows(Seq(1, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 2)), + + // Move the watermark. + AddData(inputData, 30, 31), + // op1 W (29, 29) + // agg: [30, 35) 2 + // output: None + // state: [25, 30) 5 [30, 35) 2 + // op2 W (29, 29) + // agg: None + // output: None + // state: [20, 30) (1, 7) + + // no-data batch triggered + + // op1 W (29, 31) + // agg: None + // output: [25, 30) 5 + // state: [30, 35) 2 + // op2 W (29, 31) + // agg: [20, 30) (2, 12) + // output: [20, 30) (2, 12) + // state: None + CheckNewAnswer((20, 2, 12)), + assertNumStateRows(Seq(0, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) + } + } + + test("agg -> agg -> agg, append mode") { + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val inputData = MemoryStream[Int] + + val stream = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "0 seconds") + .groupBy(window($"eventTime", "5 seconds").as("window")) + .agg(count("*").as("count")) + .groupBy(window(window_time($"window"), "10 seconds")) + .agg(count("*").as("count"), sum("count").as("sum")) + .groupBy(window(window_time($"window"), "20 seconds")) + .agg(count("*").as("count"), sum("sum").as("sum")) + .select( + $"window".getField("start").cast("long").as[Long], + $"window".getField("end").cast("long").as[Long], + $"count".as[Long], $"sum".as[Long]) + + testStream(stream)( + AddData(inputData, 0 to 37: _*), + // op1 W (0, 0) + // agg: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, [25, 30) 5, [30, 35) 5, + // [35, 40) 3 + // output: None + // state: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, [25, 30) 5, [30, 35) 5, + // [35, 40) 3 + // op2 W (0, 0) + // agg: None + // output: None + // state: None + // op3 W (0, 0) + // agg: None + // output: None + // state: None + + // no-data batch triggered + + // op1 W (0, 37) + // agg: None + // output: [0, 5) 5, [5, 10) 5, [10, 15) 5, 
[15, 20) 5, [20, 25) 5, [25, 30) 5, [30, 35) 5 + // state: [35, 40) 3 + // op2 W (0, 37) + // agg: [0, 10) (2, 10), [10, 20) (2, 10), [20, 30) (2, 10), [30, 40) (1, 5) + // output: [0, 10) (2, 10), [10, 20) (2, 10), [20, 30) (2, 10) + // state: [30, 40) (1, 5) + // op3 W (0, 37) + // agg: [0, 20) (2, 20), [20, 40) (1, 10) + // output: [0, 20) (2, 20) + // state: [20, 40) (1, 10) + CheckNewAnswer((0, 20, 2, 20)), + assertNumStateRows(Seq(1, 1, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0, 0)), + + AddData(inputData, 30 to 60: _*), + // op1 W (37, 37) + // dropped rows: [30, 35), 1 row <= note that 35, 36, 37 are still in effect + // agg: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5, [60, 65) 1 + // output: None + // state: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5, [60, 65) 1 + // op2 W (37, 37) + // output: None + // state: [30, 40) (1, 5) + // op3 W (37, 37) + // output: None + // state: [20, 40) (1, 10) + + // no-data batch + // op1 W (37, 60) + // output: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5 + // state: [60, 65) 1 + // op2 W (37, 60) + // agg: [30, 40) (2, 13), [40, 50) (2, 10), [50, 60), (2, 10) + // output: [30, 40) (2, 13), [40, 50) (2, 10), [50, 60), (2, 10) + // state: None + // op3 W (37, 60) + // agg: [20, 40) (2, 23), [40, 60) (2, 20) + // output: [20, 40) (2, 23), [40, 60) (2, 20) + // state: None + + CheckNewAnswer((20, 40, 2, 23), (40, 60, 2, 20)), + assertNumStateRows(Seq(0, 0, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0, 1)) + ) + } + } + + test("stream deduplication -> aggregation, append mode") { + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val inputData = MemoryStream[Int] + + val deduplication = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates("value", "eventTime") + + val windowedAggregation = deduplication + .groupBy(window($"eventTime", "5 seconds").as("window")) + .agg(count("*").as("count"), sum("value").as("sum")) + .select($"window".getField("start").cast("long").as[Long], + $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 1 to 15: _*), + // op1 W (0, 0) + // input: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + // deduplicated: None + // output: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + // state: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + // op2 W (0, 0) + // agg: [0, 5) 4, [5, 10) 5 [10, 15) 5, [15, 20) 1 + // output: None + // state: [0, 5) 4, [5, 10) 5 [10, 15) 5, [15, 20) 1 + + // no-data batch triggered + + // op1 W (0, 5) + // agg: None + // output: None + // state: 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + // op2 W (0, 5) + // agg: None + // output: [0, 5) 4 + // state: [5, 10) 5 [10, 15) 5, [15, 20) 1 + CheckNewAnswer((0, 4)), + assertNumStateRows(Seq(3, 10)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) + } + } + + test("join -> window agg, append mode") { + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val input1 = MemoryStream[Int] + val inputDF1 = input1.toDF + .withColumnRenamed("value", "value1") + .withColumn("eventTime1", timestamp_seconds($"value1")) + .withWatermark("eventTime1", "0 seconds") + + val input2 = MemoryStream[Int] + val inputDF2 = input2.toDF + .withColumnRenamed("value", "value2") + .withColumn("eventTime2", timestamp_seconds($"value2")) + .withWatermark("eventTime2", "0 seconds") + + val stream = inputDF1.join(inputDF2, expr("eventTime1 = eventTime2"), "inner") + 
.groupBy(window($"eventTime1", "5 seconds").as("window")) + .agg(count("*").as("count")) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(stream)( + MultiAddData(input1, 1 to 4: _*)(input2, 1 to 4: _*), + + // op1 W (0, 0) + // join output: (1, 1), (2, 2), (3, 3), (4, 4) + // state: (1, 1), (2, 2), (3, 3), (4, 4) + // op2 W (0, 0) + // agg: [0, 5) 4 + // output: None + // state: [0, 5) 4 + + // no-data batch triggered + + // op1 W (0, 4) + // join output: None + // state: None + // op2 W (0, 4) + // agg: None + // output: None + // state: [0, 5) 4 + CheckNewAnswer(), + assertNumStateRows(Seq(1, 0)), + assertNumRowsDroppedByWatermark(Seq(0, 0)), + + // Move the watermark + MultiAddData(input1, 5)(input2, 5), + + // op1 W (4, 4) + // join output: (5, 5) + // state: (5, 5) + // op2 W (4, 4) + // agg: [5, 10) 1 + // output: None + // state: [0, 5) 4, [5, 10) 1 + + // no-data batch triggered + + // op1 W (4, 5) + // join output: None + // state: None + // op2 W (4, 5) + // agg: None + // output: [0, 5) 4 + // state: [5, 10) 1 + CheckNewAnswer((0, 4)), + assertNumStateRows(Seq(1, 0)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) + } + } + + test("aggregation -> stream deduplication, append mode") { + withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { + val inputData = MemoryStream[Int] + + val aggStream = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "0 seconds") + .groupBy(window($"eventTime", "5 seconds").as("window")) + .agg(count("*").as("count")) + .withColumn("windowEnd", expr("window.end")) + + // dropDuplicates from aggStream without event time column for dropDuplicates - the + // state does not get trimmed due to watermark advancement. + val dedupNoEventTime = aggStream + .dropDuplicates("count", "windowEnd") + .select( + $"windowEnd".cast("long").as[Long], + $"count".as[Long]) + + testStream(dedupNoEventTime)( + AddData(inputData, 1, 5, 10, 15), + + // op1 W (0, 0) + // agg: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 + // output: None + // state: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 + // op2 W (0, 0) + // output: None + // state: None + + // no-data batch triggered + + // op1 W (0, 15) + // agg: None + // output: [0, 5) 1, [5, 10) 1, [10, 15) 1 + // state: [15, 20) 1 + // op2 W (0, 15) + // output: (5, 1), (10, 1), (15, 1) + // state: (5, 1), (10, 1), (15, 1) + + CheckNewAnswer((5, 1), (10, 1), (15, 1)), + assertNumStateRows(Seq(3, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) + + // Similar to the above but add event time. The dedup state will get trimmed. 
+ val dedupWithEventTime = aggStream + .withColumn("windowTime", expr("window_time(window)")) + .withColumn("windowTimeMicros", expr("unix_micros(windowTime)")) + .dropDuplicates("count", "windowEnd", "windowTime") + .select( + $"windowEnd".cast("long").as[Long], + $"windowTimeMicros".cast("long").as[Long], + $"count".as[Long]) + + testStream(dedupWithEventTime)( + AddData(inputData, 1, 5, 10, 15), + + // op1 W (0, 0) + // agg: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 + // output: None + // state: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 + // op2 W (0, 0) + // output: None + // state: None + + // no-data batch triggered + + // op1 W (0, 15) + // agg: None + // output: [0, 5) 1, [5, 10) 1, [10, 15) 1 + // state: [15, 20) 1 + // op2 W (0, 15) + // output: (5, 4999999, 1), (10, 9999999, 1), (15, 14999999, 1) + // state: None - trimmed by watermark + + CheckNewAnswer((5, 4999999, 1), (10, 9999999, 1), (15, 14999999, 1)), + assertNumStateRows(Seq(0, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) + } + } + + private def assertNumStateRows(numTotalRows: Seq[Long]): AssertOnQuery = AssertOnQuery { q => + q.processAllAvailable() + val progressWithData = q.recentProgress.lastOption.get + val stateOperators = progressWithData.stateOperators + assert(stateOperators.size === numTotalRows.size) + assert(stateOperators.map(_.numRowsTotal).toSeq === numTotalRows) + true + } + + private def assertNumRowsDroppedByWatermark( + numRowsDroppedByWatermark: Seq[Long]): AssertOnQuery = AssertOnQuery { q => + q.processAllAvailable() + val progressWithData = q.recentProgress.filterNot { p => + // filter out batches which are falling into one of types: + // 1) doesn't execute the batch run + // 2) empty input batch + p.numInputRows == 0 + }.lastOption.get + val stateOperators = progressWithData.stateOperators + assert(stateOperators.size === numRowsDroppedByWatermark.size) + assert(stateOperators.map(_.numRowsDroppedByWatermark).toSeq === numRowsDroppedByWatermark) + true + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 8ef8c21e13a33..40868f896f5ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -618,7 +618,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val numPartitions = spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS) assert(query.lastExecution.executedPlan.collect { - case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, + case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, _, ShuffleExchangeExec(opA: HashPartitioning, _, _), ShuffleExchangeExec(opB: HashPartitioning, _, _)) if partitionExpressionsColumns(opA.expressions) === Seq("a", "b")
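Finally, for pipelines that prefer the old late-record semantics, the flag introduced in SQLConf above can be turned off at the session level; a minimal sketch (`spark` is an illustrative SparkSession):

```scala
// Revert to the previous behavior: late records are filtered against the current
// microbatch watermark, and only a single stateful operator is allowed per query.
spark.conf.set("spark.sql.streaming.statefulOperator.allowMultiple", "false")
```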