diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 84795203fd174..06581e23d5854 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, CurrentDate, CurrentTimestampLike, GroupingSets, LocalTimestamp, MonotonicallyIncreasingID, SessionWindow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryComparison, CurrentDate, CurrentTimestampLike, Expression, GreaterThan, GreaterThanOrEqual, GroupingSets, LessThan, LessThanOrEqual, LocalTimestamp, MonotonicallyIncreasingID, SessionWindow} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} /** * Analyzes the presence of unsupported operations in a logical plan. @@ -42,40 +43,97 @@ object UnsupportedOperationChecker extends Logging { } /** - * Checks for possible correctness issue in chained stateful operators. The behavior is - * controlled by SQL config `spark.sql.streaming.statefulOperator.checkCorrectness.enabled`. - * Once it is enabled, an analysis exception will be thrown. Otherwise, Spark will just - * print a warning message. + * Checks if the expression has a event time column + * @param exp the expression to be checked + * @return true if it is a event time column. */ - def checkStreamingQueryGlobalWatermarkLimit( - plan: LogicalPlan, - outputMode: OutputMode): Unit = { - def isStatefulOperationPossiblyEmitLateRows(p: LogicalPlan): Boolean = p match { - case s: Aggregate - if s.isStreaming && outputMode == InternalOutputModes.Append => true - case Join(left, right, joinType, _, _) - if left.isStreaming && right.isStreaming && joinType != Inner => true - case f: FlatMapGroupsWithState - if f.isStreaming && f.outputMode == OutputMode.Append() => true - case _ => false + private def hasEventTimeCol(exp: Expression): Boolean = exp.exists { + case a: AttributeReference => a.metadata.contains(EventTimeWatermark.delayKey) + case _ => false + } + + /** + * Checks if the expression contains a range comparison, in which + * either side of the comparison is an event-time column. This is used for checking + * stream-stream time interval join. + * @param e the expression to be checked + * @return true if there is a time-interval join. + */ + private def hasRangeExprAgainstEventTimeCol(e: Expression): Boolean = { + def hasEventTimeColBinaryComp(neq: Expression): Boolean = { + val exp = neq.asInstanceOf[BinaryComparison] + hasEventTimeCol(exp.left) || hasEventTimeCol(exp.right) } - def isStatefulOperation(p: LogicalPlan): Boolean = p match { - case s: Aggregate if s.isStreaming => true - case _ @ Join(left, right, _, _, _) if left.isStreaming && right.isStreaming => true - case f: FlatMapGroupsWithState if f.isStreaming => true - case f: FlatMapGroupsInPandasWithState if f.isStreaming => true - case d: Deduplicate if d.isStreaming => true + e.exists { + case neq @ (_: LessThanOrEqual | _: LessThan | _: GreaterThanOrEqual | _: GreaterThan) => + hasEventTimeColBinaryComp(neq) case _ => false } + } - val failWhenDetected = SQLConf.get.statefulOperatorCorrectnessCheckEnabled + /** + * This method, combined with isStatefulOperation, determines all disallowed + * behaviors in multiple stateful operators. + * Concretely, All conditions defined below cannot be followed by any streaming stateful + * operator as defined in isStatefulOperation. + * @param p logical plan to be checked + * @param outputMode query output mode + * @return true if it is not allowed when followed by any streaming stateful + * operator as defined in isStatefulOperation. + */ + private def ifCannotBeFollowedByStatefulOperation( + p: LogicalPlan, outputMode: OutputMode): Boolean = p match { + case ExtractEquiJoinKeys(_, _, _, otherCondition, _, left, right, _) => + left.isStreaming && right.isStreaming && + otherCondition.isDefined && hasRangeExprAgainstEventTimeCol(otherCondition.get) + // FlatMapGroupsWithState configured with event time + case f @ FlatMapGroupsWithState(_, _, _, _, _, _, _, _, _, timeout, _, _, _, _, _, _) + if f.isStreaming && timeout == GroupStateTimeout.EventTimeTimeout => true + case p @ FlatMapGroupsInPandasWithState(_, _, _, _, _, timeout, _) + if p.isStreaming && timeout == GroupStateTimeout.EventTimeTimeout => true + case a: Aggregate if a.isStreaming && outputMode != InternalOutputModes.Append => true + // Since the Distinct node will be replaced to Aggregate in the optimizer rule + // [[ReplaceDistinctWithAggregate]], here we also need to check all Distinct node by + // assuming it as Aggregate. + case d @ Distinct(_: LogicalPlan) if d.isStreaming + && outputMode != InternalOutputModes.Append => true + case _ => false + } + + /** + * This method is only used with ifCannotBeFollowedByStatefulOperation. + * Here we list up stateful operators but there is an exception for Deduplicate: + * it is only counted here when it has an event time column. + * @param p the logical plan to be checked + * @return true if there is a streaming stateful operation + */ + private def isStatefulOperation(p: LogicalPlan): Boolean = p match { + case s: Aggregate if s.isStreaming => true + // Since the Distinct node will be replaced to Aggregate in the optimizer rule + // [[ReplaceDistinctWithAggregate]], here we also need to check all Distinct node by + // assuming it as Aggregate. + case d @ Distinct(_: LogicalPlan) if d.isStreaming => true + case _ @ Join(left, right, _, _, _) if left.isStreaming && right.isStreaming => true + case f: FlatMapGroupsWithState if f.isStreaming => true + case f: FlatMapGroupsInPandasWithState if f.isStreaming => true + case d: Deduplicate if d.isStreaming && d.keys.exists(hasEventTimeCol) => true + case _ => false + } + /** + * Checks for possible correctness issue in chained stateful operators. The behavior is + * controlled by SQL config `spark.sql.streaming.statefulOperator.checkCorrectness.enabled`. + * Once it is enabled, an analysis exception will be thrown. Otherwise, Spark will just + * print a warning message. + */ + def checkStreamingQueryGlobalWatermarkLimit(plan: LogicalPlan, outputMode: OutputMode): Unit = { + val failWhenDetected = SQLConf.get.statefulOperatorCorrectnessCheckEnabled try { plan.foreach { subPlan => if (isStatefulOperation(subPlan)) { subPlan.find { p => - (p ne subPlan) && isStatefulOperationPossiblyEmitLateRows(p) + (p ne subPlan) && ifCannotBeFollowedByStatefulOperation(p, outputMode) }.foreach { _ => val errorMsg = "Detected pattern of possible 'correctness' issue " + "due to global watermark. " + @@ -154,15 +212,7 @@ object UnsupportedOperationChecker extends Logging { "DataFrames/Datasets")(plan) } - // Disallow multiple streaming aggregations val aggregates = collectStreamingAggregates(plan) - - if (aggregates.size > 1) { - throwError( - "Multiple streaming aggregations are not supported with " + - "streaming DataFrames/Datasets")(plan) - } - // Disallow some output mode outputMode match { case InternalOutputModes.Append if aggregates.nonEmpty => @@ -266,12 +316,8 @@ object UnsupportedOperationChecker extends Logging { " DataFrame/Dataset") } if (m.isMapGroupsWithState) { // check mapGroupsWithState - // allowed only in update query output mode and without aggregation - if (aggsInQuery.nonEmpty) { - throwError( - "mapGroupsWithState is not supported with aggregation " + - "on a streaming DataFrame/Dataset") - } else if (outputMode != InternalOutputModes.Update) { + // allowed only in update query output mode + if (outputMode != InternalOutputModes.Update) { throwError( "mapGroupsWithState is not supported with " + s"$outputMode output mode on a streaming DataFrame/Dataset") @@ -294,16 +340,11 @@ object UnsupportedOperationChecker extends Logging { case _ => } } else { - // flatMapGroupsWithState with aggregation: update operation mode not allowed, and - // *groupsWithState after aggregation not allowed + // flatMapGroupsWithState with aggregation: update operation mode not allowed if (m.outputMode == InternalOutputModes.Update) { throwError( "flatMapGroupsWithState in update mode is not supported with " + "aggregation on a streaming DataFrame/Dataset") - } else if (collectStreamingAggregates(m).nonEmpty) { - throwError( - "flatMapGroupsWithState in append mode is not supported after " + - "aggregation on a streaming DataFrame/Dataset") } } } @@ -373,10 +414,6 @@ object UnsupportedOperationChecker extends Logging { } } - case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => - throwError("dropDuplicates is not supported after aggregation on a " + - "streaming DataFrame/Dataset") - case j @ Join(left, right, joinType, condition, _) => if (left.isStreaming && right.isStreaming && outputMode != InternalOutputModes.Append) { throwError("Join between two streaming DataFrames/Datasets is not supported" + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index d30bcd5af5dad..64c5ea3f5b19f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -100,12 +100,6 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { Aggregate(Nil, aggExprs("d"), streamRelation), joinType = Inner), Update) - assertNotSupportedInStreamingPlan( - "aggregate - multiple streaming aggregations", - Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), streamRelation)), - outputMode = Update, - expectedMsgs = Seq("multiple streaming aggregations")) - assertSupportedInStreamingPlan( "aggregate - streaming aggregations in update mode", Aggregate(Nil, aggExprs("d"), streamRelation), @@ -233,17 +227,6 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { SQLConf.STATEFUL_OPERATOR_CHECK_CORRECTNESS_ENABLED.key -> "false") } - for (outputMode <- Seq(Append, Update)) { - assertNotSupportedInStreamingPlan( - "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + - s"on streaming relation after aggregation in $outputMode mode", - TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, - isMapGroupsWithState = false, null, - Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), - outputMode = outputMode, - expectedMsgs = Seq("flatMapGroupsWithState", "after aggregation")) - } - assertNotSupportedInStreamingPlan( "flatMapGroupsWithState - " + "flatMapGroupsWithState(Update) on streaming relation in complete mode", @@ -315,17 +298,6 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { // future. expectedMsgs = Seq("Complete")) - for (outputMode <- Seq(Append, Update, Complete)) { - assertNotSupportedInStreamingPlan( - "mapGroupsWithState - mapGroupsWithState on streaming relation " + - s"with aggregation in $outputMode mode", - TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update, - isMapGroupsWithState = true, null, - Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), - outputMode = outputMode, - expectedMsgs = Seq("mapGroupsWithState", "with aggregation")) - } - // multiple mapGroupsWithStates assertNotSupportedInStreamingPlan( "mapGroupsWithState - multiple mapGroupsWithStates on streaming relation and all are " + @@ -369,19 +341,13 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { // Deduplicate assertSupportedInStreamingPlan( - "Deduplicate - Deduplicate on streaming relation before aggregation", + "Deduplicate - Deduplicate on streaming relation before aggregation - append", Aggregate( Seq(attributeWithWatermark), aggExprs("c"), Deduplicate(Seq(att), streamRelation)), outputMode = Append) - assertNotSupportedInStreamingPlan( - "Deduplicate - Deduplicate on streaming relation after aggregation", - Deduplicate(Seq(att), Aggregate(Nil, aggExprs("c"), streamRelation)), - outputMode = Complete, - expectedMsgs = Seq("dropDuplicates")) - assertSupportedInStreamingPlan( "Deduplicate - Deduplicate on batch relation inside a streaming query", Deduplicate(Seq(att), batchRelation), @@ -501,51 +467,217 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { "the nullable side and an appropriate range condition")) } - // stream-stream inner join doesn't emit late rows, whereas outer joins could - Seq((Inner, false), (LeftOuter, true), (RightOuter, true)).foreach { - case (joinType, expectFailure) => + // multi-aggregations only supported in Append mode + assertPassOnGlobalWatermarkLimit( + "aggregate - multiple streaming aggregations - append", + Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), streamRelation)), + outputMode = Append) + + assertFailOnGlobalWatermarkLimit( + "aggregate - multiple streaming aggregations - update", + Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), streamRelation)), + outputMode = Update) + + assertFailOnGlobalWatermarkLimit( + "aggregate - multiple streaming aggregations - complete", + Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), streamRelation)), + outputMode = Complete) + + assertPassOnGlobalWatermarkLimit( + "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + + s"on streaming relation after aggregation in Append mode", + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, null, + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), + outputMode = Append) + + // Aggregation not in Append mode followed by any stateful operators is disallowed + assertFailOnGlobalWatermarkLimit( + "flatMapGroupsWithState - flatMapGroupsWithState(Append) " + + s"on streaming relation after aggregation in Update mode", + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, null, + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), + outputMode = Update) + + // Aggregation not in Append mode followed by any stateful operators is disallowed + assertFailOnGlobalWatermarkLimit( + "mapGroupsWithState - mapGroupsWithState on streaming relation " + + "after aggregation in Update mode", + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update, + isMapGroupsWithState = true, null, + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), + outputMode = Update) + + // FlatMapGroupsWithState followed by any stateful op not allowed, here test aggregation + assertFailOnGlobalWatermarkLimit( + "multiple stateful ops - FlatMapGroupsWithState followed by agg", + Aggregate(Nil, aggExprs("c"), + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, GroupStateTimeout.EventTimeTimeout(), streamRelation)), + outputMode = Append) + + // But allows if the FlatMapGroupsWithState has timeout on processing time + assertPassOnGlobalWatermarkLimit( + "multiple stateful ops - FlatMapGroupsWithState(process time) followed by agg", + Aggregate(Nil, aggExprs("c"), + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, GroupStateTimeout.ProcessingTimeTimeout(), streamRelation)), + outputMode = Append) + + // MapGroupsWithState followed by any stateful op not allowed, here test aggregation + assertFailOnGlobalWatermarkLimit( + "multiple stateful ops - MapGroupsWithState followed by agg", + Aggregate(Nil, aggExprs("c"), + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update, + isMapGroupsWithState = true, GroupStateTimeout.EventTimeTimeout(), streamRelation)), + outputMode = Append) + + // But allows if the MapGroupsWithState has timeout on processing time + assertPassOnGlobalWatermarkLimit( + "multiple stateful ops - MapGroupsWithState(process time) followed by agg", + Aggregate(Nil, aggExprs("c"), + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Update, + isMapGroupsWithState = true, GroupStateTimeout.ProcessingTimeTimeout(), streamRelation)), + outputMode = Append) + + // stream-stream relation, time interval join can't be followed by any stateful operators + assertFailOnGlobalWatermarkLimit( + "multiple stateful ops - stream-stream time-interval join followed by agg", + Aggregate(Nil, aggExprs("c"), + streamRelation.join(streamRelation, joinType = Inner, + condition = Some(attribute === attribute && + attributeWithWatermark > attributeWithWatermark + 10))), + outputMode = Append) + + // stream-stream relation, only equality join can be followed by any stateful operators + assertPassOnGlobalWatermarkLimit( + "multiple stateful ops - stream-stream equality join followed by agg", + Aggregate(Nil, aggExprs("c"), + streamRelation.join(streamRelation, joinType = Inner, + condition = Some(attribute === attribute))), + outputMode = Append) + + // Deduplication checks: + // Deduplication, if on event time column, is a stateful operator + // and cannot be placed after FlatMapGroupsWithState + assertFailOnGlobalWatermarkLimit( + "multiple stateful ops - FlatMapGroupsWithState followed by " + + "dedup (with event-time)", + Deduplicate(Seq(attributeWithWatermark), + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, GroupStateTimeout.EventTimeTimeout(), streamRelation)), + outputMode = Append) + + // Deduplication, if not on event time column, + // although it is still a stateful operator, + // it can be placed after FlatMapGroupsWithState + assertPassOnGlobalWatermarkLimit( + "multiple stateful ops - FlatMapGroupsWithState followed by " + + "dedup (without event-time)", + Deduplicate(Seq(att), + TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, + isMapGroupsWithState = false, null, streamRelation)), + outputMode = Append) + + // Deduplication, if on event time column, is a stateful operator + // and cannot be placed after aggregation + for (outputMode <- Seq(Update, Complete)) { + assertFailOnGlobalWatermarkLimit( + s"multiple stateful ops - aggregation($outputMode mode) followed by " + + "dedup (with event-time)", + Deduplicate(Seq(attributeWithWatermark), + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), + outputMode = outputMode) + + // Deduplication, if not on event time column, + // although it is still a stateful operator, + // it can be placed after aggregation + assertPassOnGlobalWatermarkLimit( + s"multiple stateful ops - aggregation($outputMode mode) followed by " + + "dedup (without event-time)", + Deduplicate(Seq(att), + Aggregate(Seq(attributeWithWatermark), aggExprs("c"), streamRelation)), + outputMode = outputMode) + } + + // Deduplication, if on event time column, is a stateful operator + // and cannot be placed after join + assertFailOnGlobalWatermarkLimit( + "multiple stateful ops - stream-stream time interval join followed by" + + "dedup (with event-time)", + Deduplicate(Seq(attributeWithWatermark), + streamRelation.join(streamRelation, joinType = Inner, + condition = Some(attribute === attribute && + attributeWithWatermark > attributeWithWatermark + 10))), + outputMode = Append) + + // Deduplication, if not on event time column, + // although it is still a stateful operator, + // it can be placed after join + assertPassOnGlobalWatermarkLimit( + "multiple stateful ops - stream-stream time interval join followed by" + + "dedup (without event-time)", + Deduplicate(Seq(att), + streamRelation.join(streamRelation, joinType = Inner, + condition = Some(attribute === attribute && + attributeWithWatermark > attributeWithWatermark + 10))), + outputMode = Append) + + // for a stream-stream join followed by a stateful operator, + // if the join is keyed on time-interval inequality conditions (inequality on watermarked cols), + // should fail. + // if the join is keyed on time-interval equality conditions -> should pass + Seq(Inner, LeftOuter, RightOuter, FullOuter).foreach { + joinType => + assertFailOnGlobalWatermarkLimit( + s"streaming aggregation after " + + s"stream-stream $joinType join keyed on time inequality in Append mode are not supported", + streamRelation.join(streamRelation, joinType = joinType, + condition = Some(attributeWithWatermark === attribute && + attributeWithWatermark < attributeWithWatermark + 10)) + .groupBy("a")(count("*")), + outputMode = Append) + assertPassOnGlobalWatermarkLimit( s"single $joinType join in Append mode", streamRelation.join(streamRelation, joinType = RightOuter, condition = Some(attributeWithWatermark === attribute)), - OutputMode.Append()) + outputMode = Append) - testGlobalWatermarkLimit( - s"streaming aggregation after stream-stream $joinType join in Append mode", + assertPassOnGlobalWatermarkLimit( + s"streaming aggregation after " + + s"stream-stream $joinType join keyed on time equality in Append mode are supported", streamRelation.join(streamRelation, joinType = joinType, condition = Some(attributeWithWatermark === attribute)) .groupBy("a")(count("*")), - OutputMode.Append(), - expectFailure = expectFailure) + outputMode = Append) Seq(Inner, LeftOuter, RightOuter).foreach { joinType2 => - testGlobalWatermarkLimit( + assertPassOnGlobalWatermarkLimit( s"streaming-stream $joinType2 after stream-stream $joinType join in Append mode", streamRelation.join( streamRelation.join(streamRelation, joinType = joinType, condition = Some(attributeWithWatermark === attribute)), joinType = joinType2, condition = Some(attributeWithWatermark === attribute)), - OutputMode.Append(), - expectFailure = expectFailure) + outputMode = Append) } - testGlobalWatermarkLimit( + assertPassOnGlobalWatermarkLimit( s"FlatMapGroupsWithState after stream-stream $joinType join in Append mode", TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, streamRelation.join(streamRelation, joinType = joinType, condition = Some(attributeWithWatermark === attribute))), - OutputMode.Append(), - expectFailure = expectFailure) + outputMode = Append) - testGlobalWatermarkLimit( + assertPassOnGlobalWatermarkLimit( s"deduplicate after stream-stream $joinType join in Append mode", Deduplicate(Seq(attribute), streamRelation.join(streamRelation, joinType = joinType, condition = Some(attributeWithWatermark === attribute))), - OutputMode.Append(), - expectFailure = expectFailure) + outputMode = Append) } // Cogroup: only batch-batch is allowed @@ -635,40 +767,36 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { null, null, null, - new TestStreamingRelationV2(attribute)), OutputMode.Append()) + new TestStreamingRelationV2(attribute)), outputMode = Append) // streaming aggregation { assertPassOnGlobalWatermarkLimit( "single streaming aggregation in Append mode", - streamRelation.groupBy("a")(count("*")), - OutputMode.Append()) + streamRelation.groupBy("a")(count("*")), outputMode = Append) - assertFailOnGlobalWatermarkLimit( + assertPassOnGlobalWatermarkLimit( "chained streaming aggregations in Append mode", - streamRelation.groupBy("a")(count("*")).groupBy()(count("*")), - OutputMode.Append()) + streamRelation.groupBy("a")(count("*")).groupBy()(count("*")), outputMode = Append) Seq(Inner, LeftOuter, RightOuter).foreach { joinType => val plan = streamRelation.join(streamRelation.groupBy("a")(count("*")), joinType = joinType) - assertFailOnGlobalWatermarkLimit( + assertPassOnGlobalWatermarkLimit( s"$joinType join after streaming aggregation in Append mode", streamRelation.join(streamRelation.groupBy("a")(count("*")), joinType = joinType), OutputMode.Append()) } - assertFailOnGlobalWatermarkLimit( + assertPassOnGlobalWatermarkLimit( "deduplicate after streaming aggregation in Append mode", - Deduplicate(Seq(attribute), streamRelation.groupBy("a")(count("*"))), - OutputMode.Append()) + Deduplicate(Seq(attribute), streamRelation.groupBy("a")(count("*"))), OutputMode.Append()) - assertFailOnGlobalWatermarkLimit( + assertPassOnGlobalWatermarkLimit( "FlatMapGroupsWithState after streaming aggregation in Append mode", TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, - streamRelation.groupBy("a")(count("*"))), - OutputMode.Append()) + streamRelation.groupBy("a")(count("*"))), outputMode = Append) } // FlatMapGroupsWithState @@ -677,24 +805,23 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { "single FlatMapGroupsWithState in Append mode", TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, - isMapGroupsWithState = false, null, streamRelation), - OutputMode.Append()) + isMapGroupsWithState = false, null, streamRelation), outputMode = Append) assertFailOnGlobalWatermarkLimit( "streaming aggregation after FlatMapGroupsWithState in Append mode", TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, - isMapGroupsWithState = false, null, streamRelation).groupBy("*")(count("*")), - OutputMode.Append()) + isMapGroupsWithState = false, GroupStateTimeout.EventTimeTimeout(), + streamRelation).groupBy("*")(count("*")), outputMode = Append) Seq(Inner, LeftOuter, RightOuter).foreach { joinType => assertFailOnGlobalWatermarkLimit( s"stream-stream $joinType after FlatMapGroupsWithState in Append mode", streamRelation.join( TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, - isMapGroupsWithState = false, null, streamRelation), joinType = joinType, - condition = Some(attributeWithWatermark === attribute)), - OutputMode.Append()) + isMapGroupsWithState = false, GroupStateTimeout.EventTimeTimeout(), + streamRelation), joinType = joinType, + condition = Some(attributeWithWatermark === attribute)), outputMode = Append) } assertFailOnGlobalWatermarkLimit( @@ -702,30 +829,27 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, - isMapGroupsWithState = false, null, streamRelation)), - OutputMode.Append()) + isMapGroupsWithState = false, GroupStateTimeout.EventTimeTimeout(), streamRelation)), + outputMode = Append) - assertFailOnGlobalWatermarkLimit( + assertPassOnGlobalWatermarkLimit( s"deduplicate after FlatMapGroupsWithState in Append mode", Deduplicate(Seq(attribute), TestFlatMapGroupsWithState(null, att, att, Seq(att), Seq(att), att, null, Append, - isMapGroupsWithState = false, null, streamRelation)), - OutputMode.Append()) + isMapGroupsWithState = false, null, streamRelation)), outputMode = Append) } // deduplicate { assertPassOnGlobalWatermarkLimit( "streaming aggregation after deduplicate in Append mode", - Deduplicate(Seq(attribute), streamRelation).groupBy("a")(count("*")), - OutputMode.Append()) + Deduplicate(Seq(attribute), streamRelation).groupBy("a")(count("*")), outputMode = Append) Seq(Inner, LeftOuter, RightOuter).foreach { joinType => assertPassOnGlobalWatermarkLimit( s"$joinType join after deduplicate in Append mode", streamRelation.join(Deduplicate(Seq(attribute), streamRelation), joinType = joinType, - condition = Some(attributeWithWatermark === attribute)), - OutputMode.Append()) + condition = Some(attributeWithWatermark === attribute)), outputMode = Append) } assertPassOnGlobalWatermarkLimit( @@ -733,8 +857,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { TestFlatMapGroupsWithState( null, att, att, Seq(att), Seq(att), att, null, Append, isMapGroupsWithState = false, null, - Deduplicate(Seq(attribute), streamRelation)), - OutputMode.Append()) + Deduplicate(Seq(attribute), streamRelation)), outputMode = Append) } /* @@ -941,21 +1064,21 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { testNamePostfix: String, plan: LogicalPlan, outputMode: OutputMode): Unit = { - testGlobalWatermarkLimit(testNamePostfix, plan, outputMode, expectFailure = false) + testGlobalWatermarkLimit(testNamePostfix, plan, expectFailure = false, outputMode) } def assertFailOnGlobalWatermarkLimit( testNamePostfix: String, plan: LogicalPlan, outputMode: OutputMode): Unit = { - testGlobalWatermarkLimit(testNamePostfix, plan, outputMode, expectFailure = true) + testGlobalWatermarkLimit(testNamePostfix, plan, expectFailure = true, outputMode) } def testGlobalWatermarkLimit( testNamePostfix: String, plan: LogicalPlan, - outputMode: OutputMode, - expectFailure: Boolean): Unit = { + expectFailure: Boolean, + outputMode: OutputMode): Unit = { test(s"Global watermark limit - $testNamePostfix") { if (expectFailure) { withSQLConf(SQLConf.STATEFUL_OPERATOR_CHECK_CORRECTNESS_ENABLED.key -> "true") { @@ -966,10 +1089,8 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { assert(e.message.contains("Detected pattern of possible 'correctness' issue")) } } else { - withSQLConf(SQLConf.STATEFUL_OPERATOR_CHECK_CORRECTNESS_ENABLED.key -> "false") { - UnsupportedOperationChecker.checkStreamingQueryGlobalWatermarkLimit( - wrapInStreaming(plan), outputMode) - } + UnsupportedOperationChecker.checkStreamingQueryGlobalWatermarkLimit( + wrapInStreaming(plan), outputMode) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala index 0a3ea40a677ad..eb1e0de79cae7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MultiStatefulOperatorsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ @@ -40,378 +40,427 @@ class MultiStatefulOperatorsSuite } test("window agg -> window agg, append mode") { - // TODO: SPARK-40940 - Fix the unsupported ops checker to allow chaining of stateful ops. - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val inputData = MemoryStream[Int] - - val stream = inputData.toDF() - .withColumn("eventTime", timestamp_seconds($"value")) - .withWatermark("eventTime", "0 seconds") - .groupBy(window($"eventTime", "5 seconds").as("window")) - .agg(count("*").as("count")) - .groupBy(window($"window", "10 seconds")) - .agg(count("*").as("count"), sum("count").as("sum")) - .select($"window".getField("start").cast("long").as[Long], - $"count".as[Long], $"sum".as[Long]) - - testStream(stream)( - AddData(inputData, 10 to 21: _*), - // op1 W (0, 0) - // agg: [10, 15) 5, [15, 20) 5, [20, 25) 2 - // output: None - // state: [10, 15) 5, [15, 20) 5, [20, 25) 2 - // op2 W (0, 0) - // agg: None - // output: None - // state: None - - // no-data batch triggered - - // op1 W (0, 21) - // agg: None - // output: [10, 15) 5, [15, 20) 5 - // state: [20, 25) 2 - // op2 W (0, 21) - // agg: [10, 20) (2, 10) - // output: [10, 20) (2, 10) - // state: None - CheckNewAnswer((10, 2, 10)), - assertNumStateRows(Seq(0, 1)), - assertNumRowsDroppedByWatermark(Seq(0, 0)), - - AddData(inputData, 10 to 29: _*), - // op1 W (21, 21) - // agg: [10, 15) 5 - late, [15, 20) 5 - late, [20, 25) 5, [25, 30) 5 - // output: None - // state: [20, 25) 7, [25, 30) 5 - // op2 W (21, 21) - // agg: None - // output: None - // state: None - - // no-data batch triggered - - // op1 W (21, 29) - // agg: None - // output: [20, 25) 7 - // state: [25, 30) 5 - // op2 W (21, 29) - // agg: [20, 30) (1, 7) - // output: None - // state: [20, 30) (1, 7) - CheckNewAnswer(), - assertNumStateRows(Seq(1, 1)), - assertNumRowsDroppedByWatermark(Seq(0, 2)), - - // Move the watermark. - AddData(inputData, 30, 31), - // op1 W (29, 29) - // agg: [30, 35) 2 - // output: None - // state: [25, 30) 5 [30, 35) 2 - // op2 W (29, 29) - // agg: None - // output: None - // state: [20, 30) (1, 7) - - // no-data batch triggered - - // op1 W (29, 31) - // agg: None - // output: [25, 30) 5 - // state: [30, 35) 2 - // op2 W (29, 31) - // agg: [20, 30) (2, 12) - // output: [20, 30) (2, 12) - // state: None - CheckNewAnswer((20, 2, 12)), - assertNumStateRows(Seq(0, 1)), - assertNumRowsDroppedByWatermark(Seq(0, 0)) - ) - } + val inputData = MemoryStream[Int] + + val stream = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "0 seconds") + .groupBy(window($"eventTime", "5 seconds").as("window")) + .agg(count("*").as("count")) + .groupBy(window($"window", "10 seconds")) + .agg(count("*").as("count"), sum("count").as("sum")) + .select($"window".getField("start").cast("long").as[Long], + $"count".as[Long], $"sum".as[Long]) + + testStream(stream)( + AddData(inputData, 10 to 21: _*), + // op1 W (0, 0) + // agg: [10, 15) 5, [15, 20) 5, [20, 25) 2 + // output: None + // state: [10, 15) 5, [15, 20) 5, [20, 25) 2 + // op2 W (0, 0) + // agg: None + // output: None + // state: None + + // no-data batch triggered + + // op1 W (0, 21) + // agg: None + // output: [10, 15) 5, [15, 20) 5 + // state: [20, 25) 2 + // op2 W (0, 21) + // agg: [10, 20) (2, 10) + // output: [10, 20) (2, 10) + // state: None + CheckNewAnswer((10, 2, 10)), + assertNumStateRows(Seq(0, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0)), + + AddData(inputData, 10 to 29: _*), + // op1 W (21, 21) + // agg: [10, 15) 5 - late, [15, 20) 5 - late, [20, 25) 5, [25, 30) 5 + // output: None + // state: [20, 25) 7, [25, 30) 5 + // op2 W (21, 21) + // agg: None + // output: None + // state: None + + // no-data batch triggered + + // op1 W (21, 29) + // agg: None + // output: [20, 25) 7 + // state: [25, 30) 5 + // op2 W (21, 29) + // agg: [20, 30) (1, 7) + // output: None + // state: [20, 30) (1, 7) + CheckNewAnswer(), + assertNumStateRows(Seq(1, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 2)), + + // Move the watermark. + AddData(inputData, 30, 31), + // op1 W (29, 29) + // agg: [30, 35) 2 + // output: None + // state: [25, 30) 5 [30, 35) 2 + // op2 W (29, 29) + // agg: None + // output: None + // state: [20, 30) (1, 7) + + // no-data batch triggered + + // op1 W (29, 31) + // agg: None + // output: [25, 30) 5 + // state: [30, 35) 2 + // op2 W (29, 31) + // agg: [20, 30) (2, 12) + // output: [20, 30) (2, 12) + // state: None + CheckNewAnswer((20, 2, 12)), + assertNumStateRows(Seq(0, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) } test("agg -> agg -> agg, append mode") { - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val inputData = MemoryStream[Int] - - val stream = inputData.toDF() - .withColumn("eventTime", timestamp_seconds($"value")) - .withWatermark("eventTime", "0 seconds") - .groupBy(window($"eventTime", "5 seconds").as("window")) - .agg(count("*").as("count")) - .groupBy(window(window_time($"window"), "10 seconds")) - .agg(count("*").as("count"), sum("count").as("sum")) - .groupBy(window(window_time($"window"), "20 seconds")) - .agg(count("*").as("count"), sum("sum").as("sum")) - .select( - $"window".getField("start").cast("long").as[Long], - $"window".getField("end").cast("long").as[Long], - $"count".as[Long], $"sum".as[Long]) - - testStream(stream)( - AddData(inputData, 0 to 37: _*), - // op1 W (0, 0) - // agg: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, [25, 30) 5, [30, 35) 5, - // [35, 40) 3 - // output: None - // state: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, [25, 30) 5, [30, 35) 5, - // [35, 40) 3 - // op2 W (0, 0) - // agg: None - // output: None - // state: None - // op3 W (0, 0) - // agg: None - // output: None - // state: None - - // no-data batch triggered - - // op1 W (0, 37) - // agg: None - // output: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, [25, 30) 5, [30, 35) 5 - // state: [35, 40) 3 - // op2 W (0, 37) - // agg: [0, 10) (2, 10), [10, 20) (2, 10), [20, 30) (2, 10), [30, 40) (1, 5) - // output: [0, 10) (2, 10), [10, 20) (2, 10), [20, 30) (2, 10) - // state: [30, 40) (1, 5) - // op3 W (0, 37) - // agg: [0, 20) (2, 20), [20, 40) (1, 10) - // output: [0, 20) (2, 20) - // state: [20, 40) (1, 10) - CheckNewAnswer((0, 20, 2, 20)), - assertNumStateRows(Seq(1, 1, 1)), - assertNumRowsDroppedByWatermark(Seq(0, 0, 0)), - - AddData(inputData, 30 to 60: _*), - // op1 W (37, 37) - // dropped rows: [30, 35), 1 row <= note that 35, 36, 37 are still in effect - // agg: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5, [60, 65) 1 - // output: None - // state: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5, [60, 65) 1 - // op2 W (37, 37) - // output: None - // state: [30, 40) (1, 5) - // op3 W (37, 37) - // output: None - // state: [20, 40) (1, 10) - - // no-data batch - // op1 W (37, 60) - // output: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5 - // state: [60, 65) 1 - // op2 W (37, 60) - // agg: [30, 40) (2, 13), [40, 50) (2, 10), [50, 60), (2, 10) - // output: [30, 40) (2, 13), [40, 50) (2, 10), [50, 60), (2, 10) - // state: None - // op3 W (37, 60) - // agg: [20, 40) (2, 23), [40, 60) (2, 20) - // output: [20, 40) (2, 23), [40, 60) (2, 20) - // state: None - - CheckNewAnswer((20, 40, 2, 23), (40, 60, 2, 20)), - assertNumStateRows(Seq(0, 0, 1)), - assertNumRowsDroppedByWatermark(Seq(0, 0, 1)) - ) - } + val inputData = MemoryStream[Int] + + val stream = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "0 seconds") + .groupBy(window($"eventTime", "5 seconds").as("window")) + .agg(count("*").as("count")) + .groupBy(window(window_time($"window"), "10 seconds")) + .agg(count("*").as("count"), sum("count").as("sum")) + .groupBy(window(window_time($"window"), "20 seconds")) + .agg(count("*").as("count"), sum("sum").as("sum")) + .select( + $"window".getField("start").cast("long").as[Long], + $"window".getField("end").cast("long").as[Long], + $"count".as[Long], $"sum".as[Long]) + + testStream(stream)( + AddData(inputData, 0 to 37: _*), + // op1 W (0, 0) + // agg: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, [25, 30) 5, [30, 35) 5, + // [35, 40) 3 + // output: None + // state: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, [25, 30) 5, [30, 35) 5, + // [35, 40) 3 + // op2 W (0, 0) + // agg: None + // output: None + // state: None + // op3 W (0, 0) + // agg: None + // output: None + // state: None + + // no-data batch triggered + + // op1 W (0, 37) + // agg: None + // output: [0, 5) 5, [5, 10) 5, [10, 15) 5, [15, 20) 5, [20, 25) 5, [25, 30) 5, [30, 35) 5 + // state: [35, 40) 3 + // op2 W (0, 37) + // agg: [0, 10) (2, 10), [10, 20) (2, 10), [20, 30) (2, 10), [30, 40) (1, 5) + // output: [0, 10) (2, 10), [10, 20) (2, 10), [20, 30) (2, 10) + // state: [30, 40) (1, 5) + // op3 W (0, 37) + // agg: [0, 20) (2, 20), [20, 40) (1, 10) + // output: [0, 20) (2, 20) + // state: [20, 40) (1, 10) + CheckNewAnswer((0, 20, 2, 20)), + assertNumStateRows(Seq(1, 1, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0, 0)), + + AddData(inputData, 30 to 60: _*), + // op1 W (37, 37) + // dropped rows: [30, 35), 1 row <= note that 35, 36, 37 are still in effect + // agg: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5, [60, 65) 1 + // output: None + // state: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5, [60, 65) 1 + // op2 W (37, 37) + // output: None + // state: [30, 40) (1, 5) + // op3 W (37, 37) + // output: None + // state: [20, 40) (1, 10) + + // no-data batch + // op1 W (37, 60) + // output: [35, 40) 8, [40, 45) 5, [45, 50) 5, [50, 55) 5, [55, 60) 5 + // state: [60, 65) 1 + // op2 W (37, 60) + // agg: [30, 40) (2, 13), [40, 50) (2, 10), [50, 60), (2, 10) + // output: [30, 40) (2, 13), [40, 50) (2, 10), [50, 60), (2, 10) + // state: None + // op3 W (37, 60) + // agg: [20, 40) (2, 23), [40, 60) (2, 20) + // output: [20, 40) (2, 23), [40, 60) (2, 20) + // state: None + + CheckNewAnswer((20, 40, 2, 23), (40, 60, 2, 20)), + assertNumStateRows(Seq(0, 0, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0, 1)) + ) } test("stream deduplication -> aggregation, append mode") { - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val inputData = MemoryStream[Int] - - val deduplication = inputData.toDF() - .withColumn("eventTime", timestamp_seconds($"value")) - .withWatermark("eventTime", "10 seconds") - .dropDuplicates("value", "eventTime") - - val windowedAggregation = deduplication - .groupBy(window($"eventTime", "5 seconds").as("window")) - .agg(count("*").as("count"), sum("value").as("sum")) - .select($"window".getField("start").cast("long").as[Long], - $"count".as[Long]) - - testStream(windowedAggregation)( - AddData(inputData, 1 to 15: _*), - // op1 W (0, 0) - // input: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 - // deduplicated: None - // output: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 - // state: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 - // op2 W (0, 0) - // agg: [0, 5) 4, [5, 10) 5 [10, 15) 5, [15, 20) 1 - // output: None - // state: [0, 5) 4, [5, 10) 5 [10, 15) 5, [15, 20) 1 - - // no-data batch triggered - - // op1 W (0, 5) - // agg: None - // output: None - // state: 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 - // op2 W (0, 5) - // agg: None - // output: [0, 5) 4 - // state: [5, 10) 5 [10, 15) 5, [15, 20) 1 - CheckNewAnswer((0, 4)), - assertNumStateRows(Seq(3, 10)), - assertNumRowsDroppedByWatermark(Seq(0, 0)) - ) - } + val inputData = MemoryStream[Int] + + val deduplication = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates("value", "eventTime") + + val windowedAggregation = deduplication + .groupBy(window($"eventTime", "5 seconds").as("window")) + .agg(count("*").as("count"), sum("value").as("sum")) + .select($"window".getField("start").cast("long").as[Long], + $"count".as[Long]) + + testStream(windowedAggregation)( + AddData(inputData, 1 to 15: _*), + // op1 W (0, 0) + // input: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + // deduplicated: None + // output: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + // state: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + // op2 W (0, 0) + // agg: [0, 5) 4, [5, 10) 5 [10, 15) 5, [15, 20) 1 + // output: None + // state: [0, 5) 4, [5, 10) 5 [10, 15) 5, [15, 20) 1 + + // no-data batch triggered + + // op1 W (0, 5) + // agg: None + // output: None + // state: 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + // op2 W (0, 5) + // agg: None + // output: [0, 5) 4 + // state: [5, 10) 5 [10, 15) 5, [15, 20) 1 + CheckNewAnswer((0, 4)), + assertNumStateRows(Seq(3, 10)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) } test("join -> window agg, append mode") { - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val input1 = MemoryStream[Int] - val inputDF1 = input1.toDF - .withColumnRenamed("value", "value1") - .withColumn("eventTime1", timestamp_seconds($"value1")) - .withWatermark("eventTime1", "0 seconds") - - val input2 = MemoryStream[Int] - val inputDF2 = input2.toDF - .withColumnRenamed("value", "value2") - .withColumn("eventTime2", timestamp_seconds($"value2")) - .withWatermark("eventTime2", "0 seconds") - - val stream = inputDF1.join(inputDF2, expr("eventTime1 = eventTime2"), "inner") - .groupBy(window($"eventTime1", "5 seconds").as("window")) - .agg(count("*").as("count")) - .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) - - testStream(stream)( - MultiAddData(input1, 1 to 4: _*)(input2, 1 to 4: _*), - - // op1 W (0, 0) - // join output: (1, 1), (2, 2), (3, 3), (4, 4) - // state: (1, 1), (2, 2), (3, 3), (4, 4) - // op2 W (0, 0) - // agg: [0, 5) 4 - // output: None - // state: [0, 5) 4 - - // no-data batch triggered - - // op1 W (0, 4) - // join output: None - // state: None - // op2 W (0, 4) - // agg: None - // output: None - // state: [0, 5) 4 - CheckNewAnswer(), - assertNumStateRows(Seq(1, 0)), - assertNumRowsDroppedByWatermark(Seq(0, 0)), - - // Move the watermark - MultiAddData(input1, 5)(input2, 5), - - // op1 W (4, 4) - // join output: (5, 5) - // state: (5, 5) - // op2 W (4, 4) - // agg: [5, 10) 1 - // output: None - // state: [0, 5) 4, [5, 10) 1 - - // no-data batch triggered - - // op1 W (4, 5) - // join output: None - // state: None - // op2 W (4, 5) - // agg: None - // output: [0, 5) 4 - // state: [5, 10) 1 - CheckNewAnswer((0, 4)), - assertNumStateRows(Seq(1, 0)), - assertNumRowsDroppedByWatermark(Seq(0, 0)) - ) - } + val input1 = MemoryStream[Int] + val inputDF1 = input1.toDF() + .withColumnRenamed("value", "value1") + .withColumn("eventTime1", timestamp_seconds($"value1")) + .withWatermark("eventTime1", "0 seconds") + + val input2 = MemoryStream[Int] + val inputDF2 = input2.toDF() + .withColumnRenamed("value", "value2") + .withColumn("eventTime2", timestamp_seconds($"value2")) + .withWatermark("eventTime2", "0 seconds") + + val stream = inputDF1.join(inputDF2, expr("eventTime1 = eventTime2"), "inner") + .groupBy(window($"eventTime1", "5 seconds").as("window")) + .agg(count("*").as("count")) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(stream)( + MultiAddData(input1, 1 to 4: _*)(input2, 1 to 4: _*), + + // op1 W (0, 0) + // join output: (1, 1), (2, 2), (3, 3), (4, 4) + // state: (1, 1), (2, 2), (3, 3), (4, 4) + // op2 W (0, 0) + // agg: [0, 5) 4 + // output: None + // state: [0, 5) 4 + + // no-data batch triggered + + // op1 W (0, 4) + // join output: None + // state: None + // op2 W (0, 4) + // agg: None + // output: None + // state: [0, 5) 4 + CheckNewAnswer(), + assertNumStateRows(Seq(1, 0)), + assertNumRowsDroppedByWatermark(Seq(0, 0)), + + // Move the watermark + MultiAddData(input1, 5)(input2, 5), + + // op1 W (4, 4) + // join output: (5, 5) + // state: (5, 5) + // op2 W (4, 4) + // agg: [5, 10) 1 + // output: None + // state: [0, 5) 4, [5, 10) 1 + + // no-data batch triggered + + // op1 W (4, 5) + // join output: None + // state: None + // op2 W (4, 5) + // agg: None + // output: [0, 5) 4 + // state: [5, 10) 1 + CheckNewAnswer((0, 4)), + assertNumStateRows(Seq(1, 0)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) } test("aggregation -> stream deduplication, append mode") { - withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") { - val inputData = MemoryStream[Int] - - val aggStream = inputData.toDF() - .withColumn("eventTime", timestamp_seconds($"value")) - .withWatermark("eventTime", "0 seconds") - .groupBy(window($"eventTime", "5 seconds").as("window")) - .agg(count("*").as("count")) - .withColumn("windowEnd", expr("window.end")) - - // dropDuplicates from aggStream without event time column for dropDuplicates - the - // state does not get trimmed due to watermark advancement. - val dedupNoEventTime = aggStream - .dropDuplicates("count", "windowEnd") - .select( - $"windowEnd".cast("long").as[Long], - $"count".as[Long]) - - testStream(dedupNoEventTime)( - AddData(inputData, 1, 5, 10, 15), - - // op1 W (0, 0) - // agg: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 - // output: None - // state: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 - // op2 W (0, 0) - // output: None - // state: None - - // no-data batch triggered - - // op1 W (0, 15) - // agg: None - // output: [0, 5) 1, [5, 10) 1, [10, 15) 1 - // state: [15, 20) 1 - // op2 W (0, 15) - // output: (5, 1), (10, 1), (15, 1) - // state: (5, 1), (10, 1), (15, 1) - - CheckNewAnswer((5, 1), (10, 1), (15, 1)), - assertNumStateRows(Seq(3, 1)), - assertNumRowsDroppedByWatermark(Seq(0, 0)) - ) + val inputData = MemoryStream[Int] + + val aggStream = inputData.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "0 seconds") + .groupBy(window($"eventTime", "5 seconds").as("window")) + .agg(count("*").as("count")) + .withColumn("windowEnd", expr("window.end")) + + // dropDuplicates from aggStream without event time column for dropDuplicates - the + // state does not get trimmed due to watermark advancement. + val dedupNoEventTime = aggStream + .dropDuplicates("count", "windowEnd") + .select( + $"windowEnd".cast("long").as[Long], + $"count".as[Long]) + + testStream(dedupNoEventTime)( + AddData(inputData, 1, 5, 10, 15), + + // op1 W (0, 0) + // agg: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 + // output: None + // state: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 + // op2 W (0, 0) + // output: None + // state: None + + // no-data batch triggered + + // op1 W (0, 15) + // agg: None + // output: [0, 5) 1, [5, 10) 1, [10, 15) 1 + // state: [15, 20) 1 + // op2 W (0, 15) + // output: (5, 1), (10, 1), (15, 1) + // state: (5, 1), (10, 1), (15, 1) + + CheckNewAnswer((5, 1), (10, 1), (15, 1)), + assertNumStateRows(Seq(3, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) + + // Similar to the above but add event time. The dedup state will get trimmed. + val dedupWithEventTime = aggStream + .withColumn("windowTime", expr("window_time(window)")) + .withColumn("windowTimeMicros", expr("unix_micros(windowTime)")) + .dropDuplicates("count", "windowEnd", "windowTime") + .select( + $"windowEnd".cast("long").as[Long], + $"windowTimeMicros".cast("long").as[Long], + $"count".as[Long]) + + testStream(dedupWithEventTime)( + AddData(inputData, 1, 5, 10, 15), + + // op1 W (0, 0) + // agg: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 + // output: None + // state: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 + // op2 W (0, 0) + // output: None + // state: None + + // no-data batch triggered + + // op1 W (0, 15) + // agg: None + // output: [0, 5) 1, [5, 10) 1, [10, 15) 1 + // state: [15, 20) 1 + // op2 W (0, 15) + // output: (5, 4999999, 1), (10, 9999999, 1), (15, 14999999, 1) + // state: None - trimmed by watermark + + CheckNewAnswer((5, 4999999, 1), (10, 9999999, 1), (15, 14999999, 1)), + assertNumStateRows(Seq(0, 1)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) + } - // Similar to the above but add event time. The dedup state will get trimmed. - val dedupWithEventTime = aggStream - .withColumn("windowTime", expr("window_time(window)")) - .withColumn("windowTimeMicros", expr("unix_micros(windowTime)")) - .dropDuplicates("count", "windowEnd", "windowTime") - .select( - $"windowEnd".cast("long").as[Long], - $"windowTimeMicros".cast("long").as[Long], - $"count".as[Long]) - - testStream(dedupWithEventTime)( - AddData(inputData, 1, 5, 10, 15), - - // op1 W (0, 0) - // agg: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 - // output: None - // state: [0, 5) 1, [5, 10) 1, [10, 15) 1, [15, 20) 1 - // op2 W (0, 0) - // output: None - // state: None - - // no-data batch triggered - - // op1 W (0, 15) - // agg: None - // output: [0, 5) 1, [5, 10) 1, [10, 15) 1 - // state: [15, 20) 1 - // op2 W (0, 15) - // output: (5, 4999999, 1), (10, 9999999, 1), (15, 14999999, 1) - // state: None - trimmed by watermark - - CheckNewAnswer((5, 4999999, 1), (10, 9999999, 1), (15, 14999999, 1)), - assertNumStateRows(Seq(0, 1)), - assertNumRowsDroppedByWatermark(Seq(0, 0)) + test("join on time interval -> window agg, append mode, should fail") { + val input1 = MemoryStream[Int] + val inputDF1 = input1.toDF() + .withColumnRenamed("value", "value1") + .withColumn("eventTime1", timestamp_seconds($"value1")) + .withWatermark("eventTime1", "0 seconds") + + val input2 = MemoryStream[(Int, Int)] + val inputDF2 = input2.toDS().toDF("start", "end") + .withColumn("eventTime2Start", timestamp_seconds($"start")) + .withColumn("eventTime2End", timestamp_seconds($"end")) + .withColumn("start2", timestamp_seconds($"start")) + .withWatermark("eventTime2Start", "0 seconds") + + val stream = inputDF1.join(inputDF2, + expr("eventTime1 >= eventTime2Start AND eventTime1 < eventTime2End " + + "AND eventTime1 = start2"), "inner") + .groupBy(window($"eventTime1", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + val e = intercept[AnalysisException] { + testStream(stream)( + StartStream() ) } + assert(e.getMessage.contains("Detected pattern of possible 'correctness' issue")) + } + + test("join with range join on non-time intervals -> window agg, append mode, shouldn't fail") { + val input1 = MemoryStream[Int] + val inputDF1 = input1.toDF() + .withColumnRenamed("value", "value1") + .withColumn("eventTime1", timestamp_seconds($"value1")) + .withColumn("v1", timestamp_seconds($"value1")) + .withWatermark("eventTime1", "0 seconds") + + val input2 = MemoryStream[(Int, Int)] + val inputDF2 = input2.toDS().toDF("start", "end") + .withColumn("eventTime2Start", timestamp_seconds($"start")) + .withColumn("start2", timestamp_seconds($"start")) + .withColumn("end2", timestamp_seconds($"end")) + .withWatermark("eventTime2Start", "0 seconds") + + val stream = inputDF1.join(inputDF2, + expr("v1 >= start2 AND v1 < end2 " + + "AND eventTime1 = start2"), "inner") + .groupBy(window($"eventTime1", "5 seconds") as 'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(stream)( + AddData(input1, 1, 2, 3, 4), + AddData(input2, (1, 2), (2, 3), (3, 4), (4, 5)), + CheckNewAnswer(), + assertNumStateRows(Seq(1, 0)), + assertNumRowsDroppedByWatermark(Seq(0, 0)) + ) } private def assertNumStateRows(numTotalRows: Seq[Long]): AssertOnQuery = AssertOnQuery { q =>