Commit d722b2c

Revert the changes in non-stream-stream join operators

1 parent adfe796 commit d722b2c

File tree

8 files changed: 21 additions & 42 deletions


sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala

Lines changed: 5 additions & 13 deletions
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.internal.SQLConf
@@ -47,7 +46,6 @@ object AggUtils {
   }
 
   private def createAggregate(
-      requiredChildDistributionOption: Option[Seq[Distribution]] = None,
       requiredChildDistributionExpressions: Option[Seq[Expression]] = None,
       groupingExpressions: Seq[NamedExpression] = Nil,
       aggregateExpressions: Seq[AggregateExpression] = Nil,
@@ -61,7 +59,6 @@ object AggUtils {
 
     if (useHash && !forceSortAggregate) {
       HashAggregateExec(
-        requiredChildDistributionOption = requiredChildDistributionOption,
         requiredChildDistributionExpressions = requiredChildDistributionExpressions,
         groupingExpressions = groupingExpressions,
         aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
@@ -75,7 +72,6 @@ object AggUtils {
 
       if (objectHashEnabled && useObjectHash && !forceSortAggregate) {
        ObjectHashAggregateExec(
-          requiredChildDistributionOption = requiredChildDistributionOption,
          requiredChildDistributionExpressions = requiredChildDistributionExpressions,
          groupingExpressions = groupingExpressions,
          aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
@@ -85,7 +81,6 @@ object AggUtils {
          child = child)
      } else {
        SortAggregateExec(
-          requiredChildDistributionOption = requiredChildDistributionOption,
          requiredChildDistributionExpressions = requiredChildDistributionExpressions,
          groupingExpressions = groupingExpressions,
          aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
@@ -304,16 +299,12 @@ object AggUtils {
         child = child)
     }
 
-    // This is used temporarily to pick up the required child distribution for the stateful
-    // operator.
-    val tempRestored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion,
-      partialAggregate)
-
     val partialMerged1: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
       val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
       createAggregate(
-        requiredChildDistributionOption = Some(tempRestored.requiredChildDistribution),
+        requiredChildDistributionExpressions =
+          Some(groupingAttributes),
         groupingExpressions = groupingAttributes,
         aggregateExpressions = aggregateExpressions,
         aggregateAttributes = aggregateAttributes,
@@ -330,7 +321,8 @@ object AggUtils {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
       val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
       createAggregate(
-        requiredChildDistributionOption = Some(restored.requiredChildDistribution),
+        requiredChildDistributionExpressions =
+          Some(groupingAttributes),
         groupingExpressions = groupingAttributes,
         aggregateExpressions = aggregateExpressions,
         aggregateAttributes = aggregateAttributes,
@@ -357,7 +349,7 @@ object AggUtils {
       val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
 
       createAggregate(
-        requiredChildDistributionOption = Some(restored.requiredChildDistribution),
+        requiredChildDistributionExpressions = Some(groupingAttributes),
        groupingExpressions = groupingAttributes,
        aggregateExpressions = finalAggregateExpressions,
        aggregateAttributes = finalAggregateAttributes,
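
With the temporary tempRestored plan node removed, each stateful stage simply passes Some(groupingAttributes), i.e. a plain ClusteredDistribution on the grouping keys. A minimal sketch of why the single exchange the planner inserts then satisfies every downstream stage, assuming Spark catalyst on the classpath and the ClusteredDistribution signature used in this commit (the attribute name and partition count are made up):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}
import org.apache.spark.sql.types.IntegerType

object DistributionReuseSketch extends App {
  // Stand-in grouping attribute and partition count (illustrative only).
  val key = AttributeReference("key", IntegerType)()
  // What each stateful stage now requires via Some(groupingAttributes).
  val required = ClusteredDistribution(Seq(key), Some(200))
  // The one exchange the planner inserts before the first PartialMerge stage.
  val shuffle = HashPartitioning(Seq(key), 200)

  // The same hash shuffle satisfies every stage's requirement, so no extra
  // exchange is planned between the aggregates and the state store operators.
  assert(shuffle.satisfies(required))
}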

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala

Lines changed: 6 additions & 12 deletions
@@ -27,7 +27,6 @@ import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, ExplainUtil
  */
 trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning {
   def requiredChildDistributionExpressions: Option[Seq[Expression]]
-  def requiredChildDistributionOption: Option[Seq[Distribution]]
   def groupingExpressions: Seq[NamedExpression]
   def aggregateExpressions: Seq[AggregateExpression]
   def aggregateAttributes: Seq[Attribute]
@@ -91,14 +90,10 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning
   override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
 
   override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionOption match {
-      case Some(dist) => dist.toList
-      case _ =>
-        requiredChildDistributionExpressions match {
-          case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
-          case Some(exprs) => ClusteredDistribution(exprs) :: Nil
-          case None => UnspecifiedDistribution :: Nil
-        }
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
     }
   }
 
@@ -107,8 +102,7 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning
    */
   def toSortAggregate: SortAggregateExec = {
     SortAggregateExec(
-      requiredChildDistributionOption, requiredChildDistributionExpressions, groupingExpressions,
-      aggregateExpressions, aggregateAttributes, initialInputBufferOffset, resultExpressions,
-      child)
+      requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions,
+      aggregateAttributes, initialInputBufferOffset, resultExpressions, child)
   }
 }
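
The restored requiredChildDistribution is again a single three-way match on requiredChildDistributionExpressions. A dependency-free sketch with stand-in types (not Spark's real classes) showing the mapping:

// Stand-in types; Spark's real ones live in catalyst.plans.physical.
sealed trait Dist
case object AllTuplesDist extends Dist                         // Some(Nil): global aggregate, one task sees all rows
final case class ClusteredDist(keys: Seq[String]) extends Dist // Some(keys): co-locate rows by the grouping keys
case object UnspecifiedDist extends Dist                       // None: partial aggregate, any distribution will do

def requiredChildDistribution(exprs: Option[Seq[String]]): List[Dist] = exprs match {
  case Some(keys) if keys.isEmpty => AllTuplesDist :: Nil
  case Some(keys) => ClusteredDist(keys) :: Nil
  case None => UnspecifiedDist :: Nil
}

// requiredChildDistribution(Some(Nil))           == List(AllTuplesDist)
// requiredChildDistribution(Some(Seq("a", "b"))) == List(ClusteredDist(List("a", "b")))
// requiredChildDistribution(None)                == List(UnspecifiedDist)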

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 0 additions & 2 deletions
@@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution._
@@ -45,7 +44,6 @@ import org.apache.spark.util.Utils
  * Hash-based aggregate operator that can also fallback to sorting when data exceeds memory size.
  */
 case class HashAggregateExec(
-    requiredChildDistributionOption: Option[Seq[Distribution]],
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala

Lines changed: 0 additions & 2 deletions
@@ -23,7 +23,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -59,7 +58,6 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
  * }}}
  */
 case class ObjectHashAggregateExec(
-    requiredChildDistributionOption: Option[Seq[Distribution]],
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala

Lines changed: 0 additions & 2 deletions
@@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.{AliasAwareOutputOrdering, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -32,7 +31,6 @@ import org.apache.spark.sql.internal.SQLConf
  * Sort-based aggregate operator.
  */
 case class SortAggregateExec(
-    requiredChildDistributionOption: Option[Seq[Distribution]],
     requiredChildDistributionExpressions: Option[Seq[Expression]],
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, StatefulOpClusteredDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
 import org.apache.spark.sql.execution.streaming.state._
@@ -93,8 +93,8 @@ case class FlatMapGroupsWithStateExec(
    * to have the same grouping so that the data are co-located on the same task.
    */
   override def requiredChildDistribution: Seq[Distribution] = {
-    StatefulOpClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) ::
-      StatefulOpClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) ::
+    ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) ::
+      ClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) ::
       Nil
   }
 
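
The semantic difference being reverted here: StatefulOpClusteredDistribution is satisfied only by hash partitioning on exactly the full key list, whereas ClusteredDistribution also accepts a partitioning clustered on a subset of the keys. A sketch of that contrast, assuming the constructor shapes used in this commit (attributes and partition counts are illustrative):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, StatefulOpClusteredDistribution}
import org.apache.spark.sql.types.IntegerType

object LooseVsStrictSketch extends App {
  val a = AttributeReference("a", IntegerType)()
  val b = AttributeReference("b", IntegerType)()
  // A child already hash-partitioned on a subset of the grouping keys.
  val partOnA = HashPartitioning(Seq(a), 10)

  // The looser requirement accepts it; the stricter one would force a reshuffle.
  assert(partOnA.satisfies(ClusteredDistribution(Seq(a, b), Some(10))))
  assert(!partOnA.satisfies(StatefulOpClusteredDistribution(Seq(a, b), Some(10))))
}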

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala

Lines changed: 6 additions & 7 deletions
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning, StatefulOpClusteredDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution._
@@ -337,7 +337,7 @@ case class StateStoreRestoreExec(
     if (keyExpressions.isEmpty) {
       AllTuples :: Nil
     } else {
-      StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
+      ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
     }
   }
 
@@ -496,7 +496,7 @@ case class StateStoreSaveExec(
     if (keyExpressions.isEmpty) {
       AllTuples :: Nil
     } else {
-      StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
+      ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
     }
   }
 
@@ -573,8 +573,7 @@ case class SessionWindowStateStoreRestoreExec(
   }
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    StatefulOpClusteredDistribution(keyWithoutSessionExpressions,
-      stateInfo.map(_.numPartitions)) :: Nil
+    ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil
   }
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
@@ -685,7 +684,7 @@ case class SessionWindowStateStoreSaveExec(
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
+    ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
   }
 
   override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
@@ -743,7 +742,7 @@ case class StreamingDeduplicateExec(
 
   /** Distribute by grouping attributes */
   override def requiredChildDistribution: Seq[Distribution] =
-    StatefulOpClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
+    ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
 
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
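
All five operators above now share one requirement shape: cluster the child on the state keys while still pinning the partition count to the checkpointed state layout. A sketch of that shared pattern; statefulChildDistribution is a hypothetical helper, with the operators' fields passed in as parameters:

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution}

// Mirrors the reverted overrides: global aggregation when there are no keys,
// otherwise clustering on the keys with numPartitions pinned by the state store.
def statefulChildDistribution(
    keyExpressions: Seq[Expression],
    stateNumPartitions: Option[Int]): Seq[Distribution] = {
  if (keyExpressions.isEmpty) {
    AllTuples :: Nil
  } else {
    ClusteredDistribution(keyExpressions, stateNumPartitions) :: Nil
  }
}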

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -742,7 +742,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     assert(
       executedPlan.find {
         case WholeStageCodegenExec(
-          HashAggregateExec(_, _, _, _, _, _, _, _: LocalTableScanExec)) => true
+          HashAggregateExec(_, _, _, _, _, _, _: LocalTableScanExec)) => true
         case _ => false
       }.isDefined,
       "LocalTableScanExec should be within a WholeStageCodegen domain.")
