
Commit bf75b49

beliefer authored and cloud-fan committed
[SPARK-38855][SQL] DS V2 supports push down math functions
### What changes were proposed in this pull request?
Currently, Spark has some math functions from the ANSI standard. Please refer to https://github.com/apache/spark/blob/2f8613f22c0750c00cf1dcfb2f31c431d8dc1be7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L388

These functions are: `LN`, `EXP`, `POWER`, `SQRT`, `FLOOR`, `CEIL`, `WIDTH_BUCKET`.

Support for these functions in mainstream databases is shown below.

| Function | PostgreSQL | ClickHouse | H2 | MySQL | Oracle | Redshift | Presto | Teradata | Snowflake | DB2 | Vertica | Exasol | SqlServer | Yellowbrick | Impala | Mariadb | Druid | Pig | SQLite | Influxdata | Singlestore | ElasticSearch |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| `LN` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `EXP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `POWER` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes |
| `SQRT` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `FLOOR` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `CEIL` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `WIDTH_BUCKET` | Yes | No | No | No | Yes | No | Yes | Yes | Yes | Yes | Yes | No | No | No | Yes | No | No | No | No | No | No | No |

DS V2 should support pushing down these math functions.

### Why are the changes needed?
To let DS V2 push these math functions down to data sources.

### Does this PR introduce _any_ user-facing change?
No. New feature.

### How was this patch tested?
New tests.

Closes #36140 from beliefer/SPARK-38855.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
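For context (not part of the commit message), a minimal sketch of the user-facing effect, assuming a `SparkSession` named `spark` with an H2 JDBC catalog registered as `h2`, as in the `JDBCV2Suite` changes below:

```scala
import org.apache.spark.sql.functions.sqrt
import spark.implicits._

// With ANSI mode enabled, a filter built from one of the newly supported
// math functions can now be compiled into the JDBC query instead of being
// evaluated by Spark after the scan.
val df = spark.table("h2.test.employee")
  .filter(sqrt($"SALARY") > 100)

// The pushed predicate is visible in the physical plan, e.g.
// "PushedFilters: [SALARY IS NOT NULL, SQRT(CAST(SALARY AS double)) > 100.0]".
df.explain()
```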
1 parent e2683c2 commit bf75b49

File tree: 6 files changed, +145 -2 lines


sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java

Lines changed: 54 additions & 0 deletions

```diff
@@ -94,6 +94,60 @@
  *    <li>Since version: 3.3.0</li>
  *   </ul>
  *  </li>
+ *  <li>Name: <code>ABS</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>ABS(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>COALESCE</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>COALESCE(expr1, expr2)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>LN</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>LN(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>EXP</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>EXP(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>POWER</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>POWER(expr, number)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>SQRT</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>SQRT(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>FLOOR</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>FLOOR(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>CEIL</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>CEIL(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
+ *  <li>Name: <code>WIDTH_BUCKET</code>
+ *   <ul>
+ *    <li>SQL semantic: <code>WIDTH_BUCKET(expr)</code></li>
+ *    <li>Since version: 3.3.0</li>
+ *   </ul>
+ *  </li>
  * </ol>
  * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off,
  * including: add, subtract, multiply, divide, remainder, pmod.
```
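To make the documented contract concrete, a hedged sketch of how a connector-facing expression for one of these functions is assembled; the column name `salary` is illustrative:

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, GeneralScalarExpression}

// The name must be one of the function names documented above; the children
// array carries the argument expressions in order.
val sqrtOfSalary = new GeneralScalarExpression(
  "SQRT", Array[Expression](FieldReference("salary")))
```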

sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java

Lines changed: 7 additions & 0 deletions

```diff
@@ -95,6 +95,13 @@ public String build(Expression expr) {
         return visitUnaryArithmetic(name, inputToSQL(e.children()[0]));
       case "ABS":
       case "COALESCE":
+      case "LN":
+      case "EXP":
+      case "POWER":
+      case "SQRT":
+      case "FLOOR":
+      case "CEIL":
+      case "WIDTH_BUCKET":
         return visitSQLFunction(name,
           Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
       case "CASE_WHEN": {
```

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala

Lines changed: 4 additions & 0 deletions

```diff
@@ -2384,4 +2384,8 @@ object QueryCompilationErrors {
     new AnalysisException(
       "Sinks cannot request distribution and ordering in continuous execution mode")
   }
+
+  def noSuchFunctionError(database: String, funcInfo: String): Throwable = {
+    new AnalysisException(s"$database does not support function: $funcInfo")
+  }
 }
```
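A small illustration of the message this helper produces; the arguments mirror the H2 `WIDTH_BUCKET` rejection later in this commit:

```scala
import org.apache.spark.sql.errors.QueryCompilationErrors

// Builds an AnalysisException whose message reads:
// "H2 does not support function: WIDTH_BUCKET(DEPT, 1, 6, 3)"
val err = QueryCompilationErrors.noSuchFunctionError("H2", "WIDTH_BUCKET(DEPT, 1, 6, 3)")
```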

sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 27 additions & 1 deletion

```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Coalesce, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket}
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.execution.datasources.PushableColumn
@@ -104,6 +104,32 @@ class V2ExpressionBuilder(
     } else {
       None
     }
+    case Log(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v)))
+    case Exp(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v)))
+    case Pow(left, right) =>
+      val l = generateExpression(left)
+      val r = generateExpression(right)
+      if (l.isDefined && r.isDefined) {
+        Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, r.get)))
+      } else {
+        None
+      }
+    case Sqrt(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v)))
+    case Floor(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v)))
+    case Ceil(child) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v)))
+    case wb: WidthBucket =>
+      val childrenExpressions = wb.children.flatMap(generateExpression(_))
+      if (childrenExpressions.length == wb.children.length) {
+        Some(new GeneralScalarExpression("WIDTH_BUCKET",
+          childrenExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
     case and: And =>
       // AND expects predicate
       val l = generateExpression(and.left, true)
```
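A hedged sketch of the translation these new cases perform, assuming the builder's entry point is `build(): Option[V2Expression]` (it is not shown in this hunk):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Sqrt}
import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
import org.apache.spark.sql.types.DoubleType

// Catalyst's Sqrt over a double column becomes
// GeneralScalarExpression("SQRT", Array(FieldReference("salary"))).
// If any child fails to translate, the whole expression yields None,
// the all-or-nothing pattern that the POWER and WIDTH_BUCKET cases spell out.
val v2Expr = new V2ExpressionBuilder(
  Sqrt(AttributeReference("salary", DoubleType)())).build()
```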

sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala

Lines changed: 26 additions & 0 deletions

```diff
@@ -20,14 +20,40 @@ package org.apache.spark.sql.jdbc
 import java.sql.SQLException
 import java.util.Locale
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.connector.expressions.Expression
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
+import org.apache.spark.sql.errors.QueryCompilationErrors
 
 private object H2Dialect extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
 
+  class H2SQLBuilder extends JDBCSQLBuilder {
+    override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
+      funcName match {
+        case "WIDTH_BUCKET" =>
+          val functionInfo = super.visitSQLFunction(funcName, inputs)
+          throw QueryCompilationErrors.noSuchFunctionError("H2", functionInfo)
+        case _ => super.visitSQLFunction(funcName, inputs)
+      }
+    }
+  }
+
+  override def compileExpression(expr: Expression): Option[String] = {
+    val h2SQLBuilder = new H2SQLBuilder()
+    try {
+      Some(h2SQLBuilder.build(expr))
+    } catch {
+      case NonFatal(e) =>
+        logWarning("Error occurs while compiling V2 expression", e)
+        None
+    }
+  }
+
   override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
     super.compileAggregate(aggFunction).orElse(
       aggFunction match {
```
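The net effect of this dialect hook, sketched via `JdbcDialects.get` since `H2Dialect` itself is private (the single-argument `WIDTH_BUCKET` here is only for illustration):

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, GeneralScalarExpression}
import org.apache.spark.sql.jdbc.JdbcDialects

val h2 = JdbcDialects.get("jdbc:h2:mem:testdb")
val dept = Array[Expression](FieldReference("dept"))

// A supported function compiles to SQL text (exact quoting may differ):
h2.compileExpression(new GeneralScalarExpression("SQRT", dept))         // Some("SQRT(dept)")

// For WIDTH_BUCKET, H2SQLBuilder throws, the NonFatal handler logs and
// returns None, and Spark keeps the predicate instead of pushing it down:
h2.compileExpression(new GeneralScalarExpression("WIDTH_BUCKET", dept)) // None
```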

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala

Lines changed: 27 additions & 1 deletion

```diff
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Sort}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.{abs, avg, coalesce, count, count_distinct, lit, not, sum, udf, when}
+import org.apache.spark.sql.functions.{abs, avg, ceil, coalesce, count, count_distinct, exp, floor, lit, log => ln, not, pow, sqrt, sum, udf, when}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.Utils
@@ -464,6 +464,32 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       checkPushedInfo(df5, expectedPlanFragment5)
       checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true),
         Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, true)))
+
+      val df6 = spark.table("h2.test.employee")
+        .filter(ln($"dept") > 1)
+        .filter(exp($"salary") > 2000)
+        .filter(pow($"dept", 2) > 4)
+        .filter(sqrt($"salary") > 100)
+        .filter(floor($"dept") > 1)
+        .filter(ceil($"dept") > 1)
+      checkFiltersRemoved(df6, ansiMode)
+      val expectedPlanFragment6 = if (ansiMode) {
+        "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL, " +
+          "LN(CAST(DEPT AS double)) > 1.0, EXP(CAST(SALARY AS double)...,"
+      } else {
+        "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL]"
+      }
+      checkPushedInfo(df6, expectedPlanFragment6)
+      checkAnswer(df6, Seq(Row(6, "jen", 12000, 1200, true)))
+
+      // H2 does not support width_bucket
+      val df7 = sql("""
+        |SELECT * FROM h2.test.employee
+        |WHERE width_bucket(dept, 1, 6, 3) > 1
+        |""".stripMargin)
+      checkFiltersRemoved(df7, false)
+      checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL]")
+      checkAnswer(df7, Seq(Row(6, "jen", 12000, 1200, true)))
     }
   }
 }
```
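These assertions run under both ANSI settings: the `ansiMode` branch checks the pushed plan fragment, while the non-ANSI branch expects the math predicates to stay in Spark. To reproduce locally, the suite can be run with a typical Spark invocation such as `build/sbt "sql/testOnly *JDBCV2Suite"` (an assumption about the local setup, not part of this commit).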
