Skip to content

Commit decd393

Browse files
beliefer authored and cloud-fan committed
[SPARK-39135][SQL] DS V2 aggregate partial push-down should supports group by without aggregate functions
### What changes were proposed in this pull request? Currently, the SQL shown below is not supported by DS V2 aggregate partial push-down. `select key from tab group by key` ### Why are the changes needed? Make DS V2 aggregate partial push-down support group by without aggregate functions. ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? New tests Closes #36492 from beliefer/SPARK-39135. Authored-by: Jiaan Geng <beliefer@163.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 58c3613 commit decd393

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -294,7 +294,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
294294
private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
295295
// We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down.
296296
// If `Sum`, `Count`, `Avg` with distinct, can't do partial agg push down.
297-
agg.aggregateExpressions().exists {
297+
agg.aggregateExpressions().isEmpty || agg.aggregateExpressions().exists {
298298
case sum: Sum => !sum.isDistinct
299299
case count: Count => !count.isDistinct
300300
case avg: Avg => !avg.isDistinct

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -727,6 +727,57 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
727727
checkAnswer(df, Seq(Row(5)))
728728
}
729729

730+
test("scan with aggregate push-down: GROUP BY without aggregate functions") {
  // GROUP BY on a plain column with no aggregate functions: the whole
  // aggregate is pushed to the JDBC source, so Spark's Aggregate node is removed.
  val df = sql("select name FROM h2.test.employee GROUP BY name")
  checkAggregateRemoved(df)
  checkPushedInfo(df,
    "PushedAggregates: [], PushedFilters: [], PushedGroupByExpressions: [NAME],")
  checkAnswer(df, Seq(Row("alex"), Row("amy"), Row("cathy"), Row("david"), Row("jen")))

  // Same query, but read with 2 JDBC partitions: push-down still happens,
  // yet Spark must keep a final aggregate to merge partial results across
  // partitions, hence checkAggregateRemoved(..., false).
  val df2 = spark.read
    .option("partitionColumn", "dept")
    .option("lowerBound", "0")
    .option("upperBound", "2")
    .option("numPartitions", "2")
    .table("h2.test.employee")
    .groupBy($"name")
    .agg(Map.empty[String, String])
  checkAggregateRemoved(df2, false)
  checkPushedInfo(df2,
    "PushedAggregates: [], PushedFilters: [], PushedGroupByExpressions: [NAME],")
  checkAnswer(df2, Seq(Row("alex"), Row("amy"), Row("cathy"), Row("david"), Row("jen")))

  // GROUP BY on a CASE WHEN expression (not a bare column) is also pushed down
  // as a group-by expression.
  val df3 = sql("SELECT CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 END as" +
    " key FROM h2.test.employee GROUP BY key")
  checkAggregateRemoved(df3)
  checkPushedInfo(df3,
    """
      |PushedAggregates: [],
      |PushedFilters: [],
      |PushedGroupByExpressions:
      |[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END],
      |""".stripMargin.replaceAll("\n", " "))
  checkAnswer(df3, Seq(Row(0), Row(9000)))

  // CASE WHEN grouping expression combined with multiple JDBC partitions:
  // partial push-down only, so the top-level aggregate survives.
  val df4 = spark.read
    .option("partitionColumn", "dept")
    .option("lowerBound", "0")
    .option("upperBound", "2")
    .option("numPartitions", "2")
    .table("h2.test.employee")
    .groupBy(when(($"SALARY" > 8000).and($"SALARY" < 10000), $"SALARY").otherwise(0).as("key"))
    .agg(Map.empty[String, String])
  checkAggregateRemoved(df4, false)
  checkPushedInfo(df4,
    """
      |PushedAggregates: [],
      |PushedFilters: [],
      |PushedGroupByExpressions:
      |[CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00) THEN SALARY ELSE 0.00 END],
      |""".stripMargin.replaceAll("\n", " "))
  checkAnswer(df4, Seq(Row(0), Row(9000)))
}
780+
730781
test("scan with aggregate push-down: COUNT(col)") {
731782
val df = sql("select COUNT(DEPT) FROM h2.test.employee")
732783
checkAggregateRemoved(df)

0 commit comments

Comments
 (0)