
Commit 73de4c8

UT: Add more tests
1 parent 1c0399a commit 73de4c8

File tree

1 file changed: +28 -2 lines changed


sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 28 additions & 2 deletions
@@ -57,8 +57,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
       val data = Seq(("James", 1), ("James", 1), ("Phil", 1))
       val aggDF = data.toDF("name", "values").groupBy("name").sum("values")
       val partAggNode = aggDF.queryExecution.executedPlan.find {
-        case h: HashAggregateExec
-          if AggUtils.areAggExpressionsPartial(h.aggregateExpressions) => true
+        case h: HashAggregateExec =>
+          AggUtils.areAggExpressionsPartial(h.aggregateExpressions.map(_.mode))
         case _ => false
       }
       checkAnswer(aggDF, Seq(Row("James", 2), Row("Phil", 1)))
@@ -69,6 +69,32 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
     }
   }
 
+  test("Partial aggregation should not happen when no Aggregate expr") {
+    withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
+      val aggDF = testData2.select(sumDistinct($"a"))
+      val aggNodes = aggDF.queryExecution.executedPlan.collect {
+        case h: HashAggregateExec => h
+      }
+      checkAnswer(aggDF, Row(6))
+      assert(aggNodes.nonEmpty)
+      assert(aggNodes.forall(_.metrics("partialAggSkipped").value == 0))
+    }
+  }
+
+  test("Distinct: Partial aggregation should happen for" +
+    " HashAggregate nodes performing partial Aggregate operations") {
+    withSQLConf((SQLConf.SKIP_PARTIAL_AGGREGATE_ENABLED.key, "true")) {
+      val aggDF = testData2.select(sumDistinct($"a"), sum($"b"))
+      val aggNodes = aggDF.queryExecution.executedPlan.collect {
+        case h: HashAggregateExec => h
+      }
+      val (baseNodes, other) = aggNodes.partition(_.child.isInstanceOf[SerializeFromObjectExec])
+      checkAnswer(aggDF, Row(6, 9))
+      assert(baseNodes.size == 1)
+      assert(baseNodes.head.metrics("partialAggSkipped").value == testData2.count())
+      assert(other.forall(_.metrics("partialAggSkipped").value == 0))
+    }
+  }
 
   test("Aggregate with grouping keys should be included in WholeStageCodegen") {
     val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2)
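
Note: the reworked match arm in the first hunk passes the aggregate expressions' modes to AggUtils.areAggExpressionsPartial, whose definition is not part of this diff. Purely as an illustration of the contract the test appears to assume (a hypothetical sketch, not the actual AggUtils implementation in the patch):

import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Partial}

// Hypothetical helper sketch: treat a HashAggregateExec node as a partial
// aggregate when every one of its aggregate expressions runs in Partial mode.
object AggUtilsSketch {
  def areAggExpressionsPartial(modes: Seq[AggregateMode]): Boolean =
    modes.nonEmpty && modes.forall(_ == Partial)
}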
