Commit f327dad

beliefer authored and cloud-fan committed
[SPARK-38533][SQL] DS V2 aggregate push-down supports project with alias
### What changes were proposed in this pull request?
Currently, Spark DS V2 aggregate push-down doesn't support a project with aliases. Refer to https://github.com/apache/spark/blob/c91c2e9afec0d5d5bbbd2e155057fe409c5bb928/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala#L96

This PR makes aggregate push-down work correctly when the project contains aliases.

**The first example:** the original plan is shown below:
```
Aggregate [DEPT#0], [DEPT#0, sum(mySalary#8) AS total#14]
+- Project [DEPT#0, SALARY#2 AS mySalary#8]
   +- ScanBuilderHolder [DEPT#0, NAME#1, SALARY#2, BONUS#3], RelationV2[DEPT#0, NAME#1, SALARY#2, BONUS#3] test.employee, JDBCScanBuilder(org.apache.spark.sql.test.TestSparkSession77978658,StructType(StructField(DEPT,IntegerType,true),StructField(NAME,StringType,true),StructField(SALARY,DecimalType(20,2),true),StructField(BONUS,DoubleType,true)),org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions5f8da82)
```

If the aggregate can be completely pushed down, the plan becomes:
```
Project [DEPT#0, SUM(SALARY)#18 AS sum(SALARY#2)#13 AS total#14]
+- RelationV2[DEPT#0, SUM(SALARY)#18] test.employee
```

If the aggregate can only be partially pushed down, the plan becomes:
```
Aggregate [DEPT#0], [DEPT#0, sum(cast(SUM(SALARY)#18 as decimal(20,2))) AS total#14]
+- RelationV2[DEPT#0, SUM(SALARY)#18] test.employee
```

**The second example:** the original plan is shown below:
```
Aggregate [myDept#33], [myDept#33, sum(mySalary#34) AS total#40]
+- Project [DEPT#25 AS myDept#33, SALARY#27 AS mySalary#34]
   +- ScanBuilderHolder [DEPT#25, NAME#26, SALARY#27, BONUS#28], RelationV2[DEPT#25, NAME#26, SALARY#27, BONUS#28] test.employee, JDBCScanBuilder(org.apache.spark.sql.test.TestSparkSession25c4f621,StructType(StructField(DEPT,IntegerType,true),StructField(NAME,StringType,true),StructField(SALARY,DecimalType(20,2),true),StructField(BONUS,DoubleType,true)),org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions345d641e)
```

If the aggregate can be completely pushed down, the plan becomes:
```
Project [DEPT#25 AS myDept#33, SUM(SALARY)#44 AS sum(SALARY#27)#39 AS total#40]
+- RelationV2[DEPT#25, SUM(SALARY)#44] test.employee
```

If the aggregate can only be partially pushed down, the plan becomes:
```
Aggregate [myDept#33], [DEPT#25 AS myDept#33, sum(cast(SUM(SALARY)#56 as decimal(20,2))) AS total#52]
+- RelationV2[DEPT#25, SUM(SALARY)#56] test.employee
```

### Why are the changes needed?
Supporting aliases makes aggregate push-down applicable to many more queries.

### Does this PR introduce _any_ user-facing change?
Yes. DS V2 aggregate push-down now supports a project with aliases.

### How was this patch tested?
New tests.

Closes #35932 from beliefer/SPARK-38533_new.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
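For reference, a query of the following shape produces the first example plan above. This is a minimal sketch adapted from the new JDBCV2Suite tests; it assumes a JDBC catalog named `h2` exposing `test.employee`, as configured in that suite.

```scala
import org.apache.spark.sql.functions.sum
import spark.implicits._

// Aggregate over an aliased column; before this change the alias blocked push-down.
val df = spark.table("h2.test.employee")
  .select($"DEPT", $"SALARY".as("mySalary"))
  .groupBy($"DEPT")
  .agg(sum($"mySalary").as("total"))

// With complete push-down, the explain output should report something like
// "PushedAggregates: [SUM(SALARY)], PushedGroupByColumns: [DEPT]".
df.explain()
```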
1 parent ac9ae98 commit f327dad

File tree: 3 files changed, +97 −15 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala (+14 −8)

```diff
@@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Divide, DivideDTInterval, DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -34,7 +35,7 @@ import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType, StructType, YearMonthIntervalType}
 import org.apache.spark.sql.util.SchemaUtils._
 
-object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper {
   import DataSourceV2Implicits._
 
   def apply(plan: LogicalPlan): LogicalPlan = {
@@ -95,22 +96,27 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
       child match {
         case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
-            if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
+            if filters.isEmpty && CollapseProject.canCollapseExpressions(
+              resultExpressions, project, alwaysInline = true) =>
           sHolder.builder match {
             case r: SupportsPushDownAggregates =>
+              val aliasMap = getAliasMap(project)
+              val actualResultExprs = resultExpressions.map(replaceAliasButKeepName(_, aliasMap))
+              val actualGroupExprs = groupingExpressions.map(replaceAlias(_, aliasMap))
+
               val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
-              val aggregates = collectAggregates(resultExpressions, aggExprToOutputOrdinal)
+              val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal)
               val normalizedAggregates = DataSourceStrategy.normalizeExprs(
                 aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
               val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
-                groupingExpressions, sHolder.relation.output)
+                actualGroupExprs, sHolder.relation.output)
               val translatedAggregates = DataSourceStrategy.translateAggregation(
                 normalizedAggregates, normalizedGroupingExpressions)
               val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = {
                 if (translatedAggregates.isEmpty ||
                   r.supportCompletePushDown(translatedAggregates.get) ||
                   translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) {
-                  (resultExpressions, aggregates, translatedAggregates)
+                  (actualResultExprs, aggregates, translatedAggregates)
                 } else {
                   // scalastyle:off
                   // The data source doesn't support the complete push-down of this aggregation.
@@ -127,7 +133,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
                   // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19]
                   // +- ScanOperation[...]
                   // scalastyle:on
-                  val newResultExpressions = resultExpressions.map { expr =>
+                  val newResultExpressions = actualResultExprs.map { expr =>
                     expr.transform {
                       case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
                         val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
@@ -206,7 +212,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
               val scanRelation =
                 DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
               if (r.supportCompletePushDown(pushedAggregates.get)) {
-                val projectExpressions = resultExpressions.map { expr =>
+                val projectExpressions = finalResultExpressions.map { expr =>
                   // TODO At present, only push down group by attribute is supported.
                   // In future, more attribute conversion is extended here. e.g. GetStructField
                   expr.transform {
```
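To make the alias handling above concrete, here is a minimal, self-contained Scala sketch of the substitution step. The `Expr`, `Attr`, `Alias`, and `Sum` types are stand-ins invented for illustration, not Spark's Catalyst classes; the real rule uses `AliasHelper.getAliasMap`, `replaceAlias`, and `replaceAliasButKeepName` on Catalyst expressions.

```scala
// Stand-in expression types for illustration only; NOT Spark's Catalyst classes.
sealed trait Expr
case class Attr(name: String) extends Expr
case class Alias(child: Expr, name: String) extends Expr
case class Sum(child: Expr) extends Expr

object AliasSubstitutionSketch {
  // Map each alias name to the expression it stands for, like getAliasMap(project).
  def aliasMap(project: Seq[Expr]): Map[String, Expr] =
    project.collect { case Alias(child, name) => name -> child }.toMap

  // Rewrite attribute references through the alias map, like replaceAlias.
  def replaceAlias(expr: Expr, m: Map[String, Expr]): Expr = expr match {
    case Attr(name) if m.contains(name) => m(name)
    case Sum(child)                     => Sum(replaceAlias(child, m))
    case Alias(child, name)             => Alias(replaceAlias(child, m), name)
    case other                          => other
  }

  def main(args: Array[String]): Unit = {
    // Project [DEPT, SALARY AS mySalary], as in the first example plan.
    val project = Seq(Attr("DEPT"), Alias(Attr("SALARY"), "mySalary"))
    // Result expression sum(mySalary) AS total from the Aggregate node.
    val result = Alias(Sum(Attr("mySalary")), "total")
    // After substitution the aggregate refers to the base column SALARY,
    // so SUM(SALARY) can be handed to the data source.
    println(replaceAlias(result, aliasMap(project)))
    // prints: Alias(Sum(Attr(SALARY)),total)
  }
}
```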

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala (+2 −2)

```diff
@@ -184,7 +184,7 @@ trait FileSourceAggregatePushDownSuite
     }
   }
 
-  test("aggregate over alias not push down") {
+  test("aggregate over alias push down") {
     val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
       (9, "mno", 7), (2, null, 6))
     withDataSourceTable(data, "t") {
@@ -194,7 +194,7 @@ trait FileSourceAggregatePushDownSuite
       query.queryExecution.optimizedPlan.collect {
         case _: DataSourceV2ScanRelation =>
           val expected_plan_fragment =
-            "PushedAggregation: []" // aggregate alias not pushed down
+            "PushedAggregation: [MIN(_1)]"
           checkKeywordsExistsInExplain(query, expected_plan_fragment)
       }
       checkAnswer(query, Seq(Row(-2)))
```
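For a user-level view of what this test now expects, here is a hedged sketch, not the suite's exact query: the config keys, scratch path, and aliased column name are assumptions for the example. With the V2 Parquet reader and Parquet aggregate push-down enabled, MIN over an aliased column should be reported as pushed.

```scala
import org.apache.spark.sql.functions.min
import spark.implicits._

// Assumptions for this sketch: the V2 Parquet reader and Parquet aggregate
// push-down are enabled, and /tmp/agg_alias_demo is a scratch path.
spark.conf.set("spark.sql.sources.useV1SourceList", "")
spark.conf.set("spark.sql.parquet.aggregatePushdown", "true")

Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2)).toDF("_1", "_2", "_3")
  .write.mode("overwrite").parquet("/tmp/agg_alias_demo")

val query = spark.read.parquet("/tmp/agg_alias_demo")
  .select($"_1".as("myCol"))
  .agg(min($"myCol"))

// The explain output should now contain "PushedAggregation: [MIN(_1)]"
// instead of the previous "PushedAggregation: []".
query.explain()
```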

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala (+81 −5)

```diff
@@ -974,15 +974,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df, Seq(Row(1d), Row(1d), Row(null)))
   }
 
-  test("scan with aggregate push-down: aggregate over alias NOT push down") {
+  test("scan with aggregate push-down: aggregate over alias push down") {
     val cols = Seq("a", "b", "c", "d", "e")
     val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
     val df2 = df1.groupBy().sum("c")
-    checkAggregateRemoved(df2, false)
+    checkAggregateRemoved(df2)
     df2.queryExecution.optimizedPlan.collect {
-      case relation: DataSourceV2ScanRelation => relation.scan match {
-        case v1: V1ScanWrapper =>
-          assert(v1.pushedDownOperators.aggregation.isEmpty)
+      case relation: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: []"
+        checkKeywordsExistsInExplain(df2, expectedPlanFragment)
+        relation.scan match {
+          case v1: V1ScanWrapper =>
+            assert(v1.pushedDownOperators.aggregation.nonEmpty)
         }
     }
     checkAnswer(df2, Seq(Row(53000.00)))
@@ -1228,4 +1232,76 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         |ON h2.test.view1.`|col1` = h2.test.view2.`|col1`""".stripMargin)
     checkAnswer(df, Seq.empty[Row])
   }
+
+  test("scan with aggregate push-down: complete push-down aggregate with alias") {
+    val df = spark.table("h2.test.employee")
+      .select($"DEPT", $"SALARY".as("mySalary"))
+      .groupBy($"DEPT")
+      .agg(sum($"mySalary").as("total"))
+      .filter($"total" > 1000)
+    checkAggregateRemoved(df)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df, expectedPlanFragment)
+    }
+    checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00)))
+
+    val df2 = spark.table("h2.test.employee")
+      .select($"DEPT".as("myDept"), $"SALARY".as("mySalary"))
+      .groupBy($"myDept")
+      .agg(sum($"mySalary").as("total"))
+      .filter($"total" > 1000)
+    checkAggregateRemoved(df2)
+    df2.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df2, expectedPlanFragment)
+    }
+    checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00)))
+  }
+
+  test("scan with aggregate push-down: partial push-down aggregate with alias") {
+    val df = spark.read
+      .option("partitionColumn", "DEPT")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .select($"NAME", $"SALARY".as("mySalary"))
+      .groupBy($"NAME")
+      .agg(sum($"mySalary").as("total"))
+      .filter($"total" > 1000)
+    checkAggregateRemoved(df, false)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]"
+        checkKeywordsExistsInExplain(df, expectedPlanFragment)
+    }
+    checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
+      Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
+
+    val df2 = spark.read
+      .option("partitionColumn", "DEPT")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .select($"NAME".as("myName"), $"SALARY".as("mySalary"))
+      .groupBy($"myName")
+      .agg(sum($"mySalary").as("total"))
+      .filter($"total" > 1000)
+    checkAggregateRemoved(df2, false)
+    df2.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expectedPlanFragment =
+          "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]"
+        checkKeywordsExistsInExplain(df2, expectedPlanFragment)
+    }
+    checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
+      Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
+  }
 }
```
