@@ -57,8 +57,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
      val data = Seq(("James", 1), ("James", 1), ("Phil", 1))
      val aggDF = data.toDF("name", "values").groupBy("name").sum("values")
      val partAggNode = aggDF.queryExecution.executedPlan.find {
-        case h: HashAggregateExec
-          if AggUtils.areAggExpressionsPartial(h.aggregateExpressions) => true
+        case h: HashAggregateExec =>
+          AggUtils.areAggExpressionsPartial(h.aggregateExpressions.map(_.mode))
        case _ => false
      }
      checkAnswer(aggDF, Seq(Row("James", 2), Row("Phil", 1)))
@@ -69,6 +69,36 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
    }
  }

+  test("Partial aggregation should not happen when no Aggregate expr") {
+    withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
+      val aggDF = testData2.select(sumDistinct($"a"))
+      val aggNodes = aggDF.queryExecution.executedPlan.collect {
+        case h: HashAggregateExec => h
+      }
+      checkAnswer(aggDF, Row(6))
+      assert(aggNodes.nonEmpty)
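+      // A distinct-only aggregate has no regular aggregate expression whose
+      // partial phase can be skipped, so the metric should stay at zero.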
+      assert(aggNodes.forall(_.metrics("partialAggSkipped").value == 0))
+    }
+  }
+
+  test("Distinct: Partial aggregation should happen for " +
+    "HashAggregate nodes performing partial Aggregate operations") {
+    withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
+      val aggDF = testData2.select(sumDistinct($"a"), sum($"b"))
+      val aggNodes = aggDF.queryExecution.executedPlan.collect {
+        case h: HashAggregateExec => h
+      }
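+      // Only the HashAggregate sitting directly on the serialized input
+      // performs the partial aggregation; it alone should skip each input row.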
+      val (baseNodes, other) = aggNodes.partition(_.child.isInstanceOf[SerializeFromObjectExec])
+      checkAnswer(aggDF, Row(6, 9))
+      assert(baseNodes.size == 1)
+      assert(baseNodes.head.metrics("partialAggSkipped").value == testData2.count())
+      assert(other.forall(_.metrics("partialAggSkipped").value == 0))
+    }
+  }

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2)