
Commit 4a4e35a

beliefer authored and cloud-fan committed
[SPARK-38997][SQL] DS V2 aggregate push-down supports group by expressions
### What changes were proposed in this pull request?
Currently, Spark DS V2 aggregate push-down only supports group by columns. But the SQL shown below is very useful and common.
```
SELECT CASE WHEN 'SALARY' > 8000.00 AND 'SALARY' < 10000.00 THEN 'SALARY' ELSE 0.00 END AS key,
       SUM('SALARY')
FROM "test"."employee"
GROUP BY key
```

### Why are the changes needed?
Let DS V2 aggregate push-down support group by expressions.

### Does this PR introduce _any_ user-facing change?
'No'. New feature.

### How was this patch tested?
New tests.

Closes #36325 from beliefer/SPARK-38997.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit ee6ea3c)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent b25276f commit 4a4e35a
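
To make the intent concrete, here is a hedged sketch of the query shape from the description above, run through Spark SQL against a DS V2 JDBC table. The catalog and table names (`h2.test.employee`) are assumptions borrowed from Spark's JDBC test setup, not part of this commit; with this change the `CASE WHEN` grouping key can be compiled and pushed to the data source instead of being grouped in Spark.

```scala
// Sketch only: assumes a SparkSession `spark` with a JDBC-backed DS V2 catalog named `h2`
// and a `test.employee` table with a SALARY column.
val df = spark.sql(
  """SELECT CASE WHEN SALARY > 8000.00 AND SALARY < 10000.00 THEN SALARY ELSE 0.00 END AS key,
    |       SUM(SALARY)
    |FROM h2.test.employee
    |GROUP BY key""".stripMargin)

// With complete push-down, the scan node's metadata reports the pushed aggregates and
// group-by expressions (see the PushedGroupByExpressions key added in DataSourceScanExec below).
df.explain()
```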

11 files changed, +151 -69 lines changed


sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Aggregation.java

Lines changed: 5 additions & 5 deletions
@@ -20,7 +20,7 @@
 import java.io.Serializable;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 
 /**
  * Aggregation in SQL statement.
@@ -30,14 +30,14 @@
 @Evolving
 public final class Aggregation implements Serializable {
   private final AggregateFunc[] aggregateExpressions;
-  private final NamedReference[] groupByColumns;
+  private final Expression[] groupByExpressions;
 
-  public Aggregation(AggregateFunc[] aggregateExpressions, NamedReference[] groupByColumns) {
+  public Aggregation(AggregateFunc[] aggregateExpressions, Expression[] groupByExpressions) {
     this.aggregateExpressions = aggregateExpressions;
-    this.groupByColumns = groupByColumns;
+    this.groupByExpressions = groupByExpressions;
   }
 
   public AggregateFunc[] aggregateExpressions() { return aggregateExpressions; }
 
-  public NamedReference[] groupByColumns() { return groupByColumns; }
+  public Expression[] groupByExpressions() { return groupByExpressions; }
 }
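
For context, a minimal sketch of what the widened constructor now accepts. `Aggregation`, `Expressions`, `NamedReference`, and `Max` are the public DS V2 classes touched or referenced by this patch; the values are purely illustrative.

```scala
import org.apache.spark.sql.connector.expressions.{Expression, Expressions, NamedReference}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Max}

// Before this change the second constructor argument was NamedReference[] (plain columns);
// it is now Expression[], so any translatable V2 expression can serve as a group-by key.
val salary: NamedReference = Expressions.column("SALARY")

val agg = new Aggregation(
  Array[AggregateFunc](new Max(salary)),   // aggregate functions to push down
  Array[Expression](salary))               // group-by expressions, not just column references

agg.groupByExpressions().foreach(e => println(e.describe()))
```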

sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

Lines changed: 1 addition & 1 deletion
@@ -163,7 +163,7 @@ case class RowDataSourceScanExec(
       "PushedFilters" -> pushedFilters) ++
       pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
         Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())),
-          "PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++
+          "PushedGroupByExpressions" -> seqToString(v.groupByExpressions.map(_.describe())))} ++
       topNOrLimitInfo ++
       pushedDownOperators.sample.map(v => "PushedSample" ->
         s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala

Lines changed: 15 additions & 8 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min}
 import org.apache.spark.sql.execution.RowToColumnConverter
 import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
@@ -93,8 +94,8 @@ object AggregatePushDownUtils {
       return None
     }
 
-    if (aggregation.groupByColumns.nonEmpty &&
-      partitionNames.size != aggregation.groupByColumns.length) {
+    if (aggregation.groupByExpressions.nonEmpty &&
+      partitionNames.size != aggregation.groupByExpressions.length) {
       // If there are group by columns, we only push down if the group by columns are the same as
       // the partition columns. In theory, if group by columns are a subset of partition columns,
       // we should still be able to push down. e.g. if table t has partition columns p1, p2, and p3,
@@ -106,11 +107,11 @@ object AggregatePushDownUtils {
       // aggregate push down simple and don't handle this complicate case for now.
       return None
     }
-    aggregation.groupByColumns.foreach { col =>
+    aggregation.groupByExpressions.map(extractColName).foreach { colName =>
       // don't push down if the group by columns are not the same as the partition columns (orders
       // doesn't matter because reorder can be done at data source layer)
-      if (col.fieldNames.length != 1 || !isPartitionCol(col.fieldNames.head)) return None
-      finalSchema = finalSchema.add(getStructFieldForCol(col.fieldNames.head))
+      if (colName.isEmpty || !isPartitionCol(colName.get)) return None
+      finalSchema = finalSchema.add(getStructFieldForCol(colName.get))
     }
 
     aggregation.aggregateExpressions.foreach {
@@ -137,7 +138,8 @@
   def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = {
     a.aggregateExpressions.sortBy(_.hashCode())
       .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) &&
-      a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode()))
+      a.groupByExpressions.sortBy(_.hashCode())
+        .sameElements(b.groupByExpressions.sortBy(_.hashCode()))
   }
 
   /**
@@ -164,7 +166,7 @@
   def getSchemaWithoutGroupingExpression(
       aggSchema: StructType,
       aggregation: Aggregation): StructType = {
-    val numOfGroupByColumns = aggregation.groupByColumns.length
+    val numOfGroupByColumns = aggregation.groupByExpressions.length
    if (numOfGroupByColumns > 0) {
      new StructType(aggSchema.fields.drop(numOfGroupByColumns))
    } else {
@@ -179,7 +181,7 @@
      partitionSchema: StructType,
      aggregation: Aggregation,
      partitionValues: InternalRow): InternalRow = {
-    val groupByColNames = aggregation.groupByColumns.map(_.fieldNames.head)
+    val groupByColNames = aggregation.groupByExpressions.flatMap(extractColName)
    assert(groupByColNames.length == partitionSchema.length &&
      groupByColNames.length == partitionValues.numFields, "The number of group by columns " +
      s"${groupByColNames.length} should be the same as partition schema length " +
@@ -197,4 +199,9 @@
      partitionValues
    }
  }
+
+  private def extractColName(v2Expr: V2Expression): Option[String] = v2Expr match {
+    case f: FieldReference if f.fieldNames.length == 1 => Some(f.fieldNames.head)
+    case _ => None
+  }
 }
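
To illustrate the restriction the new `extractColName` helper encodes for the file sources (Parquet/ORC): a group-by expression is only usable there if it is a plain, single-name column reference that matches a partition column; anything else makes the helper return `None` and the push-down is abandoned. A rough public-API equivalent, written against `NamedReference` because `FieldReference` itself is internal to Spark, might look like this:

```scala
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NamedReference}

// A connector-side guard in the spirit of extractColName: accept only a plain,
// single-name column reference, otherwise give up on the push-down.
def singleColumnName(e: V2Expression): Option[String] = e match {
  case ref: NamedReference if ref.fieldNames.length == 1 => Some(ref.fieldNames.head)
  case _ => None  // nested fields or computed expressions cannot be matched to a partition column
}
```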

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 3 additions & 4 deletions
@@ -759,14 +759,13 @@ object DataSourceStrategy
   protected[sql] def translateAggregation(
       aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = {
 
-    def columnAsString(e: Expression): Option[FieldReference] = e match {
-      case PushableColumnWithoutNestedColumn(name) =>
-        Some(FieldReference.column(name).asInstanceOf[FieldReference])
+    def translateGroupBy(e: Expression): Option[V2Expression] = e match {
+      case PushableExpression(expr) => Some(expr)
       case _ => None
     }
 
     val translatedAggregates = aggregates.flatMap(translateAggregate)
-    val translatedGroupBys = groupBy.flatMap(columnAsString)
+    val translatedGroupBys = groupBy.flatMap(translateGroupBy)
 
     if (translatedAggregates.length != aggregates.length ||
       translatedGroupBys.length != groupBy.length) {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala

Lines changed: 1 addition & 1 deletion
@@ -519,7 +519,7 @@ object OrcUtils extends Logging {
      val orcValuesDeserializer = new OrcDeserializer(schemaWithoutGroupBy,
        (0 until schemaWithoutGroupBy.length).toArray)
      val resultRow = orcValuesDeserializer.deserializeFromValues(aggORCValues)
-      if (aggregation.groupByColumns.nonEmpty) {
+      if (aggregation.groupByExpressions.nonEmpty) {
        val reOrderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol(
          partitionSchema, aggregation, partitionValues)
        new JoinedRow(reOrderedPartitionValues, resultRow)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala

Lines changed: 1 addition & 1 deletion
@@ -279,7 +279,7 @@ object ParquetUtils {
        throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i))
    }
 
-    if (aggregation.groupByColumns.nonEmpty) {
+    if (aggregation.groupByExpressions.nonEmpty) {
      val reorderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol(
        partitionSchema, aggregation, partitionValues)
      new JoinedRow(reorderedPartitionValues, converter.currentRecord)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala

Lines changed: 16 additions & 7 deletions
@@ -183,9 +183,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
        // scalastyle:on
        val newOutput = scan.readSchema().toAttributes
        assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
-        val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
-          case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
-          case (_, b) => b
+        val groupByExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
+        val groupAttrs = normalizedGroupingExpressions.zip(newOutput).zipWithIndex.map {
+          case ((a: Attribute, b: Attribute), _) => b.withExprId(a.exprId)
+          case ((expr, attr), ordinal) =>
+            if (!groupByExprToOutputOrdinal.contains(expr.canonicalized)) {
+              groupByExprToOutputOrdinal(expr.canonicalized) = ordinal
+            }
+            attr
        }
        val aggOutput = newOutput.drop(groupAttrs.length)
        val output = groupAttrs ++ aggOutput
@@ -196,7 +201,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
            |Pushed Aggregate Functions:
            | ${pushedAggregates.get.aggregateExpressions.mkString(", ")}
            |Pushed Group by:
-            | ${pushedAggregates.get.groupByColumns.mkString(", ")}
+            | ${pushedAggregates.get.groupByExpressions.mkString(", ")}
            |Output: ${output.mkString(", ")}
         """.stripMargin)
@@ -205,14 +210,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
          DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
        if (r.supportCompletePushDown(pushedAggregates.get)) {
          val projectExpressions = finalResultExpressions.map { expr =>
-            // TODO At present, only push down group by attribute is supported.
-            // In future, more attribute conversion is extended here. e.g. GetStructField
-            expr.transform {
+            expr.transformDown {
              case agg: AggregateExpression =>
                val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
                val child =
                  addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType)
                Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId)
+              case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
+                val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
+                addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
            }
          }.asInstanceOf[Seq[NamedExpression]]
          Project(projectExpressions, scanRelation)
@@ -255,6 +261,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
                case other => other
              }
              agg.copy(aggregateFunction = aggFunction)
+            case expr if groupByExprToOutputOrdinal.contains(expr.canonicalized) =>
+              val ordinal = groupByExprToOutputOrdinal(expr.canonicalized)
+              addCastIfNeeded(groupAttrs(ordinal), expr.dataType)
          }
        }
      }
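
The new `groupByExprToOutputOrdinal` map exists because a pushed group-by key that is not a simple attribute (e.g. the `CASE WHEN` key) still appears verbatim in the final aggregate and project expressions; the rule records which scan output column holds each canonicalized group-by expression and later substitutes that column, adding a cast if needed. A toy, plain-Scala illustration of that bookkeeping, with strings standing in for Catalyst expressions, is shown below.

```scala
// Toy model: strings stand in for canonicalized Catalyst expressions and scan output attributes.
val groupingExpressions = Seq("CASE WHEN salary > 8000.0 THEN salary ELSE 0.0 END", "dept")
val scanOutput          = Seq("group_col_0", "group_col_1", "sum_salary")

// Record which scan output ordinal each group-by expression landed in.
val groupByExprToOutputOrdinal: Map[String, Int] = groupingExpressions.zipWithIndex.toMap

// Rewriting a final result expression: if it is one of the pushed group-by keys,
// replace it with the corresponding scan output column; otherwise leave it alone.
def rewrite(expr: String): String =
  groupByExprToOutputOrdinal.get(expr).map(scanOutput).getOrElse(expr)

assert(rewrite(groupingExpressions.head) == "group_col_0")
assert(rewrite("some_other_expr") == "some_other_expr")
```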

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala

Lines changed: 14 additions & 13 deletions
@@ -20,7 +20,7 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.expressions.SortOrder
+import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
@@ -70,12 +70,15 @@ case class JDBCScanBuilder(
 
   private var pushedAggregateList: Array[String] = Array()
 
-  private var pushedGroupByCols: Option[Array[String]] = None
+  private var pushedGroupBys: Option[Array[String]] = None
 
   override def supportCompletePushDown(aggregation: Aggregation): Boolean = {
-    lazy val fieldNames = aggregation.groupByColumns()(0).fieldNames()
+    lazy val fieldNames = aggregation.groupByExpressions()(0) match {
+      case field: FieldReference => field.fieldNames
+      case _ => Array.empty[String]
+    }
     jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) ||
-      (aggregation.groupByColumns().length == 1 && fieldNames.length == 1 &&
+      (aggregation.groupByExpressions().length == 1 && fieldNames.length == 1 &&
        jdbcOptions.partitionColumn.exists(fieldNames(0).equalsIgnoreCase(_)))
   }
 
@@ -86,28 +89,26 @@ case class JDBCScanBuilder(
    val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate)
    if (compiledAggs.length != aggregation.aggregateExpressions.length) return false
 
-    val groupByCols = aggregation.groupByColumns.map { col =>
-      if (col.fieldNames.length != 1) return false
-      dialect.quoteIdentifier(col.fieldNames.head)
-    }
+    val compiledGroupBys = aggregation.groupByExpressions.flatMap(dialect.compileExpression)
+    if (compiledGroupBys.length != aggregation.groupByExpressions.length) return false
 
    // The column names here are already quoted and can be used to build sql string directly.
    // e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") =>
    // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
    // GROUP BY "DEPT", "NAME"
-    val selectList = groupByCols ++ compiledAggs
-    val groupByClause = if (groupByCols.isEmpty) {
+    val selectList = compiledGroupBys ++ compiledAggs
+    val groupByClause = if (compiledGroupBys.isEmpty) {
      ""
    } else {
-      "GROUP BY " + groupByCols.mkString(",")
+      "GROUP BY " + compiledGroupBys.mkString(",")
    }
 
    val aggQuery = s"SELECT ${selectList.mkString(",")} FROM ${jdbcOptions.tableOrQuery} " +
      s"WHERE 1=0 $groupByClause"
    try {
      finalSchema = JDBCRDD.getQueryOutputSchema(aggQuery, jdbcOptions, dialect)
      pushedAggregateList = selectList
-      pushedGroupByCols = Some(groupByCols)
+      pushedGroupBys = Some(compiledGroupBys)
      true
    } catch {
      case NonFatal(e) =>
@@ -173,6 +174,6 @@ case class JDBCScanBuilder(
    // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't
    // be used in sql string.
    JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedPredicate,
-      pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortOrders)
+      pushedAggregateList, pushedGroupBys, tableSample, pushedLimit, sortOrders)
  }
}
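
As the code above shows, the builder still validates the pushed SQL by issuing a probe query with `WHERE 1=0`, but the select list and GROUP BY clause now carry compiled group-by expressions rather than quoted column names. A standalone sketch of the string it builds follows; the compiled fragments and table identifier are illustrative stand-ins for what `dialect.compileExpression` / `dialect.compileAggregate` might produce for the query in the commit message.

```scala
// Illustrative dialect output for the CASE WHEN grouping key and the SUM aggregate.
val compiledGroupBys =
  Seq("""CASE WHEN "SALARY" > 8000.00 AND "SALARY" < 10000.00 THEN "SALARY" ELSE 0.00 END""")
val compiledAggs = Seq("""SUM("SALARY")""")

val selectList    = compiledGroupBys ++ compiledAggs
val groupByClause = if (compiledGroupBys.isEmpty) "" else "GROUP BY " + compiledGroupBys.mkString(",")
val tableOrQuery  = "\"test\".\"employee\""

// WHERE 1=0 returns no rows; the probe query only serves to discover the pushed result schema.
val aggQuery = s"SELECT ${selectList.mkString(",")} FROM $tableOrQuery WHERE 1=0 $groupByClause"
println(aggQuery)
```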

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ case class OrcScan(
 
  lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) {
    (seqToString(pushedAggregate.get.aggregateExpressions),
-      seqToString(pushedAggregate.get.groupByColumns))
+      seqToString(pushedAggregate.get.groupByExpressions))
  } else {
    ("[]", "[]")
  }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ case class ParquetScan(
 
  lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) {
    (seqToString(pushedAggregate.get.aggregateExpressions),
-      seqToString(pushedAggregate.get.groupByColumns))
+      seqToString(pushedAggregate.get.groupByExpressions))
  } else {
    ("[]", "[]")
  }
