
Commit 782a0a8

Huaxin Gao authored and committed
address comments
1 parent ef4bab9 commit 782a0a8

File tree

7 files changed (+138, -82 lines)


docs/sql-data-sources-jdbc.md

Lines changed: 7 additions & 0 deletions
@@ -211,6 +211,13 @@ the following case-insensitive options:
       Specifies kerberos principal name for the JDBC client. If both <code>keytab</code> and <code>principal</code> are defined then Spark tries to do kerberos authentication.
     </td>
   </tr>
+
+  <tr>
+    <td><code>pushDownAggregate</code></td>
+    <td>
+      The option to enable or disable aggregate push-down into the JDBC data source. The default value is false, in which case Spark will NOT push down aggregates to the JDBC data source. If set to true, aggregates will be pushed down to the JDBC data source and handled by the data source instead of by Spark. Please note that aggregates are pushed down if and only if all the aggregates and the related filters can be pushed down. Aggregate push-down is usually turned off when the aggregate is performed faster by Spark than by the JDBC data source.
+    </td>
+  </tr>
 </table>
 
 Note that kerberos authentication with keytab is not always supported by the JDBC driver.<br>
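To make the new option concrete, here is a minimal usage sketch (a SparkSession named `spark` is assumed, and the JDBC URL, table, and column names are hypothetical placeholders, not part of this commit):

// Minimal sketch: enable aggregate push-down on a JDBC read.
val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://dbserver/test")
  .option("dbtable", "employees")
  .option("pushDownAggregate", "true")
  .load()

// If every aggregate and its related filters can be translated, MAX and
// the grouping are evaluated by the database rather than by Spark.
df.groupBy("DEPT").max("SALARY").show()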

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala

Lines changed: 1 addition & 60 deletions
@@ -17,11 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types._
 
 // scalastyle:off line.size.limit
 @ExpressionDescription(
@@ -44,62 +40,7 @@ import org.apache.spark.sql.types._
   group = "agg_funcs",
   since = "1.0.0")
 // scalastyle:on line.size.limit
-case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
-
-  override def nullable: Boolean = false
-
-  // Return data type.
-  override def dataType: DataType = LongType
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (children.isEmpty && !SQLConf.get.getConf(SQLConf.ALLOW_PARAMETERLESS_COUNT)) {
-      TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least one argument. " +
-        s"If you have to call the function $prettyName without arguments, set the legacy " +
-        s"configuration `${SQLConf.ALLOW_PARAMETERLESS_COUNT.key}` as true")
-    } else {
-      TypeCheckResult.TypeCheckSuccess
-    }
-  }
-
-  protected lazy val count = AttributeReference("count", LongType, nullable = false)()
-
-  override lazy val aggBufferAttributes = count :: Nil
-
-  override lazy val initialValues = Seq(
-    /* count = */ Literal(0L)
-  )
-
-  override lazy val mergeExpressions = Seq(
-    /* count = */ count.left + count.right
-  )
-
-  override lazy val evaluateExpression = count
-
-  override def defaultResult: Option[Literal] = Option(Literal(0L))
-
-  private[sql] var pushDown: Boolean = false
-
-  override lazy val updateExpressions = {
-    if (!pushDown) {
-      val nullableChildren = children.filter(_.nullable)
-      if (nullableChildren.isEmpty) {
-        Seq(
-          /* count = */ count + 1L
-        )
-      } else {
-        Seq(
-          /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L)
-        )
-      }
-    } else {
-      Seq(
-        // if count is pushed down to Data Source layer, add the count result retrieved from
-        // Data Source
-        /* count = */ count + children.head
-      )
-    }
-  }
-}
+case class Count(children: Seq[Expression]) extends CountBase(children)
 
 object Count {
   def apply(child: Expression): Count = Count(child :: Nil)
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountBase.scala (new file; path inferred from the package and class name)

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+abstract class CountBase(children: Seq[Expression]) extends DeclarativeAggregate {
+
+  override def nullable: Boolean = false
+
+  // Return data type.
+  override def dataType: DataType = LongType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.isEmpty && !SQLConf.get.getConf(SQLConf.ALLOW_PARAMETERLESS_COUNT)) {
+      TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least one argument. " +
+        s"If you have to call the function $prettyName without arguments, set the legacy " +
+        s"configuration `${SQLConf.ALLOW_PARAMETERLESS_COUNT.key}` as true")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  protected lazy val count = AttributeReference("count", LongType, nullable = false)()
+
+  override lazy val aggBufferAttributes = count :: Nil
+
+  override lazy val initialValues = Seq(
+    /* count = */ Literal(0L)
+  )
+
+  override lazy val mergeExpressions = Seq(
+    /* count = */ count.left + count.right
+  )
+
+  override lazy val evaluateExpression = count
+
+  override def defaultResult: Option[Literal] = Option(Literal(0L))
+
+  override lazy val updateExpressions = {
+    val nullableChildren = children.filter(_.nullable)
+    if (nullableChildren.isEmpty) {
+      Seq(
+        /* count = */ count + 1L
+      )
+    } else {
+      Seq(
+        /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L)
+      )
+    }
+  }
+}
+
+
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PushDownCount.scala (new file; path inferred from the package and class name)

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.LongType
+
+case class PushDownCount(children: Seq[Expression]) extends CountBase(children) {
+
+  override protected lazy val count =
+    AttributeReference("PushDownCount", LongType, nullable = false)()
+
+  override lazy val updateExpressions = {
+    Seq(
+      // if count is pushed down to Data Source layer, add the count result retrieved from
+      // Data Source
+      /* count = */ count + children.head
+    )
+  }
+}
+
+object PushDownCount {
+  def apply(child: Expression): PushDownCount =
+    PushDownCount(child :: Nil)
+}
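The behavioral difference between Count and PushDownCount is worth spelling out: a regular count adds 1 per non-null input row, while a pushed-down count receives pre-aggregated partial counts from the data source and sums them. A standalone sketch of the two update rules (plain Scala for illustration, not Catalyst expressions):

// Regular Count: increment once per non-null input row.
def regularUpdate(count: Long, row: Option[Any]): Long =
  row.fold(count)(_ => count + 1L)

// PushDownCount: each incoming value is already a partial count
// computed by the data source, so it is added rather than counted.
def pushDownUpdate(count: Long, partialCount: Long): Long =
  count + partialCount

// e.g. two source partitions report partial counts 3 and 4:
val total = Seq(3L, 4L).foldLeft(0L)(pushDownUpdate)  // 7, not 2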

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 1 addition & 3 deletions
@@ -721,7 +721,6 @@ object DataSourceStrategy
   }
 
   protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = {
-
     aggregates.aggregateFunction match {
       case min: aggregate.Min =>
         val colName = columnAsString(min.child)
@@ -737,8 +736,7 @@ object DataSourceStrategy
         if (colName.nonEmpty) Some(Sum(colName, sum.dataType, aggregates.isDistinct)) else None
       case count: aggregate.Count =>
         val columnName = count.children.head match {
-          case Literal(_, _) =>
-            "1"
+          case Literal(_, _) => "1"
           case _ => columnAsString(count.children.head)
         }
         if (columnName.nonEmpty) {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 14 additions & 16 deletions
@@ -257,32 +257,30 @@ private[jdbc] class JDBCRDD(
    */
   override def getPartitions: Array[Partition] = partitions
 
-  private var updatedSchema: StructType = new StructType()
+  private val (updatedSchema, updatedCol): (StructType, Array[String]) =
+    if (aggregation.aggregateExpressions.isEmpty) {
+      (schema, columns)
+    } else {
+      getAggregateSchemaAndCol
+    }
 
   /**
    * `columns`, but as a String suitable for injection into a SQL query.
    */
   private val columnList: String = {
     val sb = new StringBuilder()
-    if(aggregation.aggregateExpressions.isEmpty) {
-      updatedSchema = schema
-      columns.foreach(x => sb.append(",").append(x))
-    } else {
-      val (compiledAgg, aggDataType) =
-        JDBCRDD.compileAggregates(aggregation.aggregateExpressions, JdbcDialects.get(url))
-      getAggregateColumnsList(sb, compiledAgg, aggDataType)
-    }
-    if (sb.length == 0) "1" else sb.substring(1)
+    updatedCol.foreach(x => sb.append(",").append(x))
+    if (sb.isEmpty) "1" else sb.substring(1)
   }
 
-  /*
+  /**
    * Build the column lists for Aggregates push down:
    * each of the Aggregates + groupBy columns
    */
-  private def getAggregateColumnsList(
-      sb: StringBuilder,
-      compiledAgg: Array[String],
-      aggDataType: Array[DataType]): Unit = {
+  private def getAggregateSchemaAndCol(): (StructType, Array[String]) = {
+    var updatedSchema: StructType = new StructType()
+    val (compiledAgg, aggDataType) =
+      JDBCRDD.compileAggregates(aggregation.aggregateExpressions, JdbcDialects.get(url))
     val colDataTypeMap: Map[String, StructField] = columns.zip(schema.fields).toMap
     val newColsBuilder = ArrayBuilder.make[String]
     for ((col, dataType) <- compiledAgg.zip(aggDataType)) {
@@ -294,7 +292,7 @@ private[jdbc] class JDBCRDD(
         newColsBuilder += quotedGroupBy
         updatedSchema = updatedSchema.add(colDataTypeMap.get(quotedGroupBy).get)
       }
-    sb.append(", ").append(newColsBuilder.result.mkString(", "))
+    (updatedSchema, newColsBuilder.result)
   }
 
   /**
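For intuition about what getAggregateSchemaAndCol produces: the compiled aggregate expressions plus the quoted group-by columns become the SELECT list of the query sent over JDBC. A standalone sketch with hypothetical column names (the real quoting comes from the JdbcDialect):

// Compiled aggregates plus quoted group-by columns form the SELECT list.
val compiledAgg = Array("MAX(\"SALARY\")", "MIN(\"BONUS\")")
val groupByCols = Array("\"DEPT\"")
val columnList  = (compiledAgg ++ groupByCols).mkString(", ")
// Prints: SELECT MAX("SALARY"), MIN("BONUS"), "DEPT" FROM employees GROUP BY "DEPT"
println(s"SELECT $columnList FROM employees GROUP BY ${groupByCols.mkString(", ")}")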

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala

Lines changed: 1 addition & 3 deletions
@@ -125,9 +125,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with AliasHelper {
         } else if (agg.aggregateFunction.isInstanceOf[aggregate.Sum]) {
           aggregate.Sum(aggOutput(i - 1))
         } else if (agg.aggregateFunction.isInstanceOf[aggregate.Count]) {
-          val count = aggregate.Count(aggOutput(i - 1))
-          count.pushDown = true
-          count
+          aggregate.PushDownCount(aggOutput(i - 1))
         } else {
           agg.aggregateFunction
         }
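One plausible motivation for replacing the mutable pushDown flag with a dedicated PushDownCount node: Catalyst routinely rebuilds expression trees with copy-style transformations, and a var that is not a constructor parameter does not survive such copies. A standalone sketch of the hazard (illustrative only, not Catalyst code):

// A mutable flag outside the constructor is silently reset by copy().
case class Flagged(child: String) { var pushDown: Boolean = false }

val a = Flagged("count")
a.pushDown = true
val b = a.copy()        // tree rewrites create fresh instances like this
assert(!b.pushDown)     // the flag is lost; behavior silently changes

// Encoding the behavior in the type, as PushDownCount does, cannot be
// lost by copying: the subclass's overridden updateExpressions travels
// with every instance.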
