
Commit 060a2b8

beliefer authored and cloud-fan committed
[SPARK-42131][SQL] Extract the function that construct the select statement for JDBC dialect
### What changes were proposed in this pull request?
Currently, JDBCRDD uses a fixed format for the SELECT statement:
```
val sqlText = options.prepareQuery +
  s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" +
  s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause $myOffsetClause"
```
But some databases use different syntax. For example, MS SQL Server uses the keyword TOP to express a LIMIT clause or a top-N query. The LIMIT clause of MS SQL Server is shown below:
```
SELECT TOP(1) Model, Color, Price FROM dbo.Cars WHERE Color = 'blue'
```
The top-N of MS SQL Server is shown below:
```
SELECT TOP(1) Model, Color, Price FROM dbo.Cars WHERE Color = 'blue' ORDER BY Price ASC
```
This PR lets each JDBC dialect define its own syntax.

### Why are the changes needed?
Extracts the function that constructs the SELECT statement, so JDBC dialects can override it.

### Does this PR introduce _any_ user-facing change?
'No'. New feature.

### How was this patch tested?
N/A

Closes apache#39667 from beliefer/SPARK-42131.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent d4e5df8 commit 060a2b8

File tree: 11 files changed, +311 -65 lines

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala

Lines changed: 1 addition & 0 deletions

```diff
@@ -62,6 +62,7 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
     .set("spark.sql.catalog.db2", classOf[JDBCTableCatalog].getName)
     .set("spark.sql.catalog.db2.url", db.getJdbcUrl(dockerIp, externalPort))
     .set("spark.sql.catalog.db2.pushDownAggregate", "true")
+    .set("spark.sql.catalog.db2.pushDownLimit", "true")

   override def tablePreparation(connection: Connection): Unit = {
     connection.prepareStatement(
```

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala

Lines changed: 1 addition & 0 deletions

```diff
@@ -59,6 +59,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
     .set("spark.sql.catalog.mssql", classOf[JDBCTableCatalog].getName)
     .set("spark.sql.catalog.mssql.url", db.getJdbcUrl(dockerIp, externalPort))
     .set("spark.sql.catalog.mssql.pushDownAggregate", "true")
+    .set("spark.sql.catalog.mssql.pushDownLimit", "true")

   override val connectionTimeout = timeout(7.minutes)
```

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala

Lines changed: 1 addition & 0 deletions

```diff
@@ -55,6 +55,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
     .set("spark.sql.catalog.mysql", classOf[JDBCTableCatalog].getName)
     .set("spark.sql.catalog.mysql.url", db.getJdbcUrl(dockerIp, externalPort))
     .set("spark.sql.catalog.mysql.pushDownAggregate", "true")
+    .set("spark.sql.catalog.mysql.pushDownLimit", "true")

   override val connectionTimeout = timeout(7.minutes)
```

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala

Lines changed: 1 addition & 0 deletions

```diff
@@ -76,6 +76,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
     .set("spark.sql.catalog.oracle", classOf[JDBCTableCatalog].getName)
     .set("spark.sql.catalog.oracle.url", db.getJdbcUrl(dockerIp, externalPort))
     .set("spark.sql.catalog.oracle.pushDownAggregate", "true")
+    .set("spark.sql.catalog.oracle.pushDownLimit", "true")

   override val connectionTimeout = timeout(7.minutes)
```
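
Each of the four suites above enables limit pushdown for its catalog via `pushDownLimit`. As a hedged sketch of the same setup in a standalone session (the catalog name `mydb` and the connection URL are placeholders, not from this PR):

```scala
import org.apache.spark.sql.SparkSession

// Placeholder catalog name and URL; this mirrors the .set(...) calls above.
val spark = SparkSession.builder()
  .appName("jdbc-limit-pushdown-demo")
  .config("spark.sql.catalog.mydb",
    "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
  .config("spark.sql.catalog.mydb.url", "jdbc:mysql://host:3306/testdb")
  .config("spark.sql.catalog.mydb.pushDownAggregate", "true")
  .config("spark.sql.catalog.mydb.pushDownLimit", "true")
  .getOrCreate()

// With pushDownLimit enabled, the LIMIT below can be compiled into the remote SELECT.
spark.sql("SELECT name, salary FROM mydb.testdb.employee LIMIT 1").show()
```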

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala

Lines changed: 45 additions & 2 deletions

```diff
@@ -21,10 +21,11 @@ import org.apache.logging.log4j.Level

 import org.apache.spark.sql.{AnalysisException, DataFrame}
 import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException, UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sample}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sample, Sort}
 import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog}
 import org.apache.spark.sql.connector.catalog.index.SupportsIndex
+import org.apache.spark.sql.connector.expressions.NullOrdering
 import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite
@@ -402,7 +403,49 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
     }
   }

-  protected def checkAggregateRemoved(df: DataFrame): Unit = {
+  private def checkSortRemoved(df: DataFrame): Unit = {
+    val sorts = df.queryExecution.optimizedPlan.collect {
+      case s: Sort => s
+    }
+    assert(sorts.isEmpty)
+  }
+
+  test("simple scan with LIMIT") {
+    val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
+      s"${caseConvert("employee")} WHERE dept > 0 LIMIT 1")
+    assert(limitPushed(df, 1))
+    val rows = df.collect()
+    assert(rows.length === 1)
+    assert(rows(0).getString(0) === "amy")
+    assert(rows(0).getDecimal(1) === new java.math.BigDecimal("10000.00"))
+    assert(rows(0).getDouble(2) === 1000d)
+  }
+
+  test("simple scan with top N") {
+    Seq(NullOrdering.values()).flatten.foreach { nullOrdering =>
+      val df1 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
+        s"${caseConvert("employee")} WHERE dept > 0 ORDER BY salary $nullOrdering LIMIT 1")
+      assert(limitPushed(df1, 1))
+      checkSortRemoved(df1)
+      val rows1 = df1.collect()
+      assert(rows1.length === 1)
+      assert(rows1(0).getString(0) === "cathy")
+      assert(rows1(0).getDecimal(1) === new java.math.BigDecimal("9000.00"))
+      assert(rows1(0).getDouble(2) === 1200d)
+
+      val df2 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
+        s"${caseConvert("employee")} WHERE dept > 0 ORDER BY bonus DESC $nullOrdering LIMIT 1")
+      assert(limitPushed(df2, 1))
+      checkSortRemoved(df2)
+      val rows2 = df2.collect()
+      assert(rows2.length === 1)
+      assert(rows2(0).getString(0) === "david")
+      assert(rows2(0).getDecimal(1) === new java.math.BigDecimal("10000.00"))
+      assert(rows2(0).getDouble(2) === 1300d)
+    }
+  }
+
+  private def checkAggregateRemoved(df: DataFrame): Unit = {
     val aggregates = df.queryExecution.optimizedPlan.collect {
       case agg: Aggregate => agg
     }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 13 additions & 58 deletions

```diff
@@ -178,54 +178,6 @@ private[jdbc] class JDBCRDD(
    */
   override def getPartitions: Array[Partition] = partitions

-  /**
-   * `columns`, but as a String suitable for injection into a SQL query.
-   */
-  private val columnList: String = if (columns.isEmpty) "1" else columns.mkString(",")
-
-  /**
-   * `filters`, but as a WHERE clause suitable for injection into a SQL query.
-   */
-  private val filterWhereClause: String = {
-    val dialect = JdbcDialects.get(url)
-    predicates.flatMap(dialect.compileExpression(_)).map(p => s"($p)").mkString(" AND ")
-  }
-
-  /**
-   * A WHERE clause representing both `filters`, if any, and the current partition.
-   */
-  private def getWhereClause(part: JDBCPartition): String = {
-    if (part.whereClause != null && filterWhereClause.length > 0) {
-      "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
-    } else if (part.whereClause != null) {
-      "WHERE " + part.whereClause
-    } else if (filterWhereClause.length > 0) {
-      "WHERE " + filterWhereClause
-    } else {
-      ""
-    }
-  }
-
-  /**
-   * A GROUP BY clause representing pushed-down grouping columns.
-   */
-  private def getGroupByClause: String = {
-    if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) {
-      // The GROUP BY columns should already be quoted by the caller side.
-      s"GROUP BY ${groupByColumns.get.mkString(", ")}"
-    } else {
-      ""
-    }
-  }
-
-  private def getOrderByClause: String = {
-    if (sortOrders.nonEmpty) {
-      s" ORDER BY ${sortOrders.mkString(", ")}"
-    } else {
-      ""
-    }
-  }
-
   /**
    * Runs the SQL query against the JDBC driver.
    *
@@ -299,20 +251,23 @@
     // fully-qualified table name in the SELECT statement. I don't know how to
     // talk about a table in a completely portable way.

-    val myWhereClause = getWhereClause(part)
+    var builder = dialect
+      .getJdbcSQLQueryBuilder(options)
+      .withColumns(columns)
+      .withPredicates(predicates, part)
+      .withSortOrders(sortOrders)
+      .withLimit(limit)
+      .withOffset(offset)

-    val myTableSampleClause: String = if (sample.nonEmpty) {
-      JdbcDialects.get(url).getTableSample(sample.get)
-    } else {
-      ""
+    groupByColumns.foreach { groupByKeys =>
+      builder = builder.withGroupByColumns(groupByKeys)
     }

-    val myLimitClause: String = dialect.getLimitClause(limit)
-    val myOffsetClause: String = dialect.getOffsetClause(offset)
+    sample.foreach { tableSampleInfo =>
+      builder = builder.withTableSample(tableSampleInfo)
+    }

-    val sqlText = options.prepareQuery +
-      s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" +
-      s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause $myOffsetClause"
+    val sqlText = builder.build()
     stmt = conn.prepareStatement(sqlText,
       ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
     stmt.setFetchSize(options.fetchSize)
```
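
To make the rewritten flow concrete, here is a hedged sketch of driving the builder directly with a dialect that does not override it (the URL, table, and column names are placeholders, not from this PR):

```scala
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.JdbcDialects

// Placeholder options; an unrecognized URL resolves to the default dialect,
// which returns the base JdbcSQLQueryBuilder.
val options = new JDBCOptions(Map(
  "url" -> "jdbc:mydb://host:1234/testdb",
  "dbtable" -> "employee"))
val dialect = JdbcDialects.get(options.url)

val sqlText = dialect.getJdbcSQLQueryBuilder(options)
  .withColumns(Array("\"name\"", "\"salary\""))
  .withSortOrders(Array("\"salary\" ASC"))
  .withLimit(1)
  .build()
// Yields roughly (modulo the spacing left by the empty clauses):
//   SELECT "name","salary" FROM employee ORDER BY "salary" ASC LIMIT 1
```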

sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala

Lines changed: 8 additions & 2 deletions

```diff
@@ -538,19 +538,25 @@ abstract class JdbcDialect extends Serializable with Logging {
   }

   /**
-   * returns the LIMIT clause for the SELECT statement
+   * Returns the LIMIT clause for the SELECT statement
    */
   def getLimitClause(limit: Integer): String = {
     if (limit > 0 ) s"LIMIT $limit" else ""
   }

   /**
-   * returns the OFFSET clause for the SELECT statement
+   * Returns the OFFSET clause for the SELECT statement
    */
   def getOffsetClause(offset: Integer): String = {
     if (offset > 0 ) s"OFFSET $offset" else ""
   }

+  /**
+   * Returns the SQL builder for the SELECT statement.
+   */
+  def getJdbcSQLQueryBuilder(options: JDBCOptions): JdbcSQLQueryBuilder =
+    new JdbcSQLQueryBuilder(this, options)
+
   def supportsTableSample: Boolean = false

   def getTableSample(sample: TableSampleInfo): String =
```
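
`getJdbcSQLQueryBuilder` becomes the extension point for dialects whose statement shape differs as a whole; a dialect that only needs a different clause keyword can keep overriding the per-clause hooks. As a hedged sketch of the latter (paraphrasing the pattern Spark's DB2 dialect uses; this snippet is not part of this diff):

```scala
import org.apache.spark.sql.jdbc.JdbcDialect

// Hypothetical dialect: only the LIMIT keyword differs, so overriding the
// per-clause hook is enough; the default JdbcSQLQueryBuilder picks it up
// through dialect.getLimitClause(limit) in build().
case object FetchFirstDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  override def getLimitClause(limit: Integer): String = {
    if (limit > 0) s"FETCH FIRST $limit ROWS ONLY" else ""
  }
}
```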
sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala

Lines changed: 167 additions & 0 deletions
```diff
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition}
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
+
+/**
+ * The builder to build a single SELECT query.
+ *
+ * Note: All the `withXXX` methods will be invoked at most once. The invocation order does not
+ * matter, as all these clauses follow the natural SQL order: sample the table first, then filter,
+ * then group by, then sort, then offset, then limit.
+ *
+ * @since 3.5.0
+ */
+class JdbcSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions) {
+
+  /**
+   * `columns`, but as a String suitable for injection into a SQL query.
+   */
+  protected var columnList: String = "1"
+
+  /**
+   * A WHERE clause representing both `filters`, if any, and the current partition.
+   */
+  protected var whereClause: String = ""
+
+  /**
+   * A GROUP BY clause representing pushed-down grouping columns.
+   */
+  protected var groupByClause: String = ""
+
+  /**
+   * An ORDER BY clause representing a pushed-down sort or top-N.
+   */
+  protected var orderByClause: String = ""
+
+  /**
+   * A LIMIT value representing a pushed-down limit.
+   */
+  protected var limit: Int = -1
+
+  /**
+   * An OFFSET value representing a pushed-down offset.
+   */
+  protected var offset: Int = -1
+
+  /**
+   * A table sample clause representing a pushed-down table sample.
+   */
+  protected var tableSampleClause: String = ""
+
+  /**
+   * The column names, following the dialect's SQL syntax.
+   * e.g. a column name may be the raw name or the quoted name.
+   */
+  def withColumns(columns: Array[String]): JdbcSQLQueryBuilder = {
+    if (columns.nonEmpty) {
+      columnList = columns.mkString(",")
+    }
+    this
+  }
+
+  /**
+   * Constructs the WHERE clause following the dialect's SQL syntax.
+   */
+  def withPredicates(predicates: Array[Predicate], part: JDBCPartition): JdbcSQLQueryBuilder = {
+    // `filters`, but as a WHERE clause suitable for injection into a SQL query.
+    val filterWhereClause: String = {
+      predicates.flatMap(dialect.compileExpression(_)).map(p => s"($p)").mkString(" AND ")
+    }
+
+    // A WHERE clause representing both `filters`, if any, and the current partition.
+    whereClause = if (part.whereClause != null && filterWhereClause.length > 0) {
+      "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
+    } else if (part.whereClause != null) {
+      "WHERE " + part.whereClause
+    } else if (filterWhereClause.length > 0) {
+      "WHERE " + filterWhereClause
+    } else {
+      ""
+    }
+
+    this
+  }
+
+  /**
+   * Constructs the GROUP BY clause following the dialect's SQL syntax.
+   */
+  def withGroupByColumns(groupByColumns: Array[String]): JdbcSQLQueryBuilder = {
+    if (groupByColumns.nonEmpty) {
+      // The GROUP BY columns should already be quoted by the caller side.
+      groupByClause = s"GROUP BY ${groupByColumns.mkString(", ")}"
+    }
+
+    this
+  }
+
+  /**
+   * Constructs the ORDER BY clause following the dialect's SQL syntax.
+   */
+  def withSortOrders(sortOrders: Array[String]): JdbcSQLQueryBuilder = {
+    if (sortOrders.nonEmpty) {
+      orderByClause = s" ORDER BY ${sortOrders.mkString(", ")}"
+    }
+
+    this
+  }
+
+  /**
+   * Saves the limit value used to construct the LIMIT clause.
+   */
+  def withLimit(limit: Int): JdbcSQLQueryBuilder = {
+    this.limit = limit
+
+    this
+  }
+
+  /**
+   * Saves the offset value used to construct the OFFSET clause.
+   */
+  def withOffset(offset: Int): JdbcSQLQueryBuilder = {
+    this.offset = offset
+
+    this
+  }
+
+  /**
+   * Constructs the table sample clause following the dialect's SQL syntax.
+   */
+  def withTableSample(sample: TableSampleInfo): JdbcSQLQueryBuilder = {
+    tableSampleClause = dialect.getTableSample(sample)
+
+    this
+  }
+
+  /**
+   * Builds the final SQL query following the dialect's SQL syntax.
+   */
+  def build(): String = {
+    // Constructs the LIMIT clause following the dialect's SQL syntax.
+    val limitClause = dialect.getLimitClause(limit)
+    // Constructs the OFFSET clause following the dialect's SQL syntax.
+    val offsetClause = dialect.getOffsetClause(offset)
+
+    options.prepareQuery +
+      s"SELECT $columnList FROM ${options.tableOrQuery} $tableSampleClause" +
+      s" $whereClause $groupByClause $orderByClause $limitClause $offsetClause"
+  }
+}
```
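
The point of the extraction is that a dialect can now replace the whole statement shape. A hedged sketch of how the MS SQL Server `TOP(n)` syntax from the PR description could be produced by subclassing the builder (the class name and wiring are illustrative, not Spark's actual MsSqlServerDialect code):

```scala
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcSQLQueryBuilder}

// Hypothetical builder: renders the pushed-down limit as TOP(n) right after
// SELECT instead of a trailing LIMIT clause. The protected fields used below
// (limit, columnList, whereClause, orderByClause, ...) come from the base class.
class TopNQueryBuilder(dialect: JdbcDialect, options: JDBCOptions)
  extends JdbcSQLQueryBuilder(dialect, options) {

  override def build(): String = {
    val topClause = if (limit > 0) s"TOP($limit) " else ""

    options.prepareQuery +
      s"SELECT $topClause$columnList FROM ${options.tableOrQuery}" +
      s" $tableSampleClause $whereClause $groupByClause $orderByClause"
  }
}

// A dialect would then return it from the new hook:
//   override def getJdbcSQLQueryBuilder(options: JDBCOptions): JdbcSQLQueryBuilder =
//     new TopNQueryBuilder(this, options)
```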
