
Commit a730da9

beliefer authored and chenzhx committed
[SPARK-38761][SQL] DS V2 supports push down misc non-aggregate functions
### What changes were proposed in this pull request?
Spark currently has some misc non-aggregate functions from the ANSI standard; see https://github.com/apache/spark/blob/2f8613f22c0750c00cf1dcfb2f31c431d8dc1be7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala#L362. These functions are: `abs`, `coalesce`, `nullif`, and `CASE WHEN`. DS V2 should support pushing down these misc non-aggregate functions.

Because DS V2 already supports pushing down `CASE WHEN`, this PR does not need to handle it again. Because `nullif` extends `RuntimeReplaceable`, it needs no handling either.

### Why are the changes needed?
DS V2 will support pushing down misc non-aggregate functions.

### Does this PR introduce _any_ user-facing change?
No. New feature.

### How was this patch tested?
New tests.

Closes apache#36039 from beliefer/SPARK-38761.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 93690a0 commit a730da9
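
To illustrate the user-visible effect, here is a minimal sketch based on the new `JDBCV2Suite` test case in this commit. It assumes an active `SparkSession` named `spark` with the suite's H2 catalog registered as `h2`, and ANSI mode enabled for the full pushdown:

```scala
import org.apache.spark.sql.functions.{abs, coalesce}
import spark.implicits._ // for the $"..." column syntax

// With this change, filters built from abs() and coalesce() compile to
// ABS(...) / COALESCE(...) and are pushed to the JDBC source instead of
// being evaluated by Spark after the scan.
val df = spark.table("h2.test.employee")
  .filter(abs($"dept" - 3) > 1)
  .filter(coalesce($"salary", $"bonus") > 2000)
df.explain()
// Expected fragment (per the new test, under ANSI mode):
// PushedFilters: [DEPT IS NOT NULL, ABS(DEPT - 3) > 1,
//   (COALESCE(CAST(SALARY AS double), BONUS)) > 2000.0]
```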

File tree

3 files changed: +44 −25 lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java

Lines changed: 8 additions & 0 deletions

```diff
@@ -93,6 +93,10 @@ public String build(Expression expr) {
         return visitNot(build(e.children()[0]));
       case "~":
         return visitUnaryArithmetic(name, inputToSQL(e.children()[0]));
+      case "ABS":
+      case "COALESCE":
+        return visitSQLFunction(name,
+          Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
       case "CASE_WHEN": {
         List<String> children =
           Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList());
@@ -210,6 +214,10 @@ protected String visitCaseWhen(String[] children) {
     return sb.toString();
   }
 
+  protected String visitSQLFunction(String funcName, String[] inputs) {
+    return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
+  }
+
   protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
     throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
   }
```

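The new `visitSQLFunction` hook simply renders a function call as `NAME(arg1, arg2, ...)`. A standalone Scala sketch of the same string-building logic (names are illustrative; the real method is the Java one above):

```scala
// Mirrors V2ExpressionSQLBuilder.visitSQLFunction: join the already-built
// child SQL strings with ", " and wrap them in the function name.
def visitSQLFunction(funcName: String, inputs: Array[String]): String =
  s"$funcName(${inputs.mkString(", ")})"

visitSQLFunction("COALESCE", Array("CAST(SALARY AS double)", "BONUS"))
// => "COALESCE(CAST(SALARY AS double), BONUS)"
```
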
sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 10 additions & 1 deletion

```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import org.apache.spark.sql.catalyst.expressions.{Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Coalesce, Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull, Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate, Subtract, UnaryMinus}
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.execution.datasources.PushableColumn
@@ -96,6 +96,15 @@ class V2ExpressionBuilder(
       }
     case Cast(child, dataType, _, true) =>
       generateExpression(child).map(v => new V2Cast(v, dataType))
+    case Abs(child, true) => generateExpression(child)
+      .map(v => new GeneralScalarExpression("ABS", Array[V2Expression](v)))
+    case Coalesce(children) =>
+      val childrenExpressions = children.flatMap(generateExpression(_))
+      if (children.length == childrenExpressions.length) {
+        Some(new GeneralScalarExpression("COALESCE", childrenExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
     case and: And =>
       // AND expects predicate
       val l = generateExpression(and.left, true)
```

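Note the all-or-nothing guard in the `Coalesce` branch: children are translated with `flatMap`, and the V2 expression is emitted only if every child translated, since pushing down a partially translated `COALESCE` would change its semantics. A self-contained sketch of that pattern (simplified stand-in types, illustrative only):

```scala
// Simplified stand-in for generateExpression: returns None for inputs the
// builder cannot translate to a V2 expression.
def translate(e: String): Option[String] =
  if (e.forall(_.isLetterOrDigit)) Some(e.toUpperCase) else None

def translateCoalesce(children: Seq[String]): Option[String] = {
  val translated = children.flatMap(translate)
  // All-or-nothing: emit the pushed-down form only when every child
  // translated; otherwise keep the whole function in Spark.
  if (translated.length == children.length) {
    Some(s"COALESCE(${translated.mkString(", ")})")
  } else {
    None
  }
}

translateCoalesce(Seq("salary", "bonus"))  // Some("COALESCE(SALARY, BONUS)")
translateCoalesce(Seq("salary", "bo nus")) // None: second child fails
```
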
sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala

Lines changed: 26 additions & 24 deletions

```diff
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.{avg, count, count_distinct, lit, not, sum, udf, when}
+import org.apache.spark.sql.functions.{abs, avg, coalesce, count, count_distinct, lit, not, sum, udf, when}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.util.Utils
@@ -381,19 +381,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df, Seq(Row("fred", 1), Row("mary", 2)))
 
     val df2 = spark.table("h2.test.people").filter($"id" + Int.MaxValue > 1)
-
     checkFiltersRemoved(df2, ansiMode)
-
-    df2.queryExecution.optimizedPlan.collect {
-      case _: DataSourceV2ScanRelation =>
-        val expected_plan_fragment = if (ansiMode) {
-          "PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], "
-        } else {
-          "PushedFilters: [ID IS NOT NULL], "
-        }
-        checkKeywordsExistsInExplain(df2, expected_plan_fragment)
+    val expectedPlanFragment2 = if (ansiMode) {
+      "PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], "
+    } else {
+      "PushedFilters: [ID IS NOT NULL], "
     }
-
+    checkPushedInfo(df2, expectedPlanFragment2)
     if (ansiMode) {
       val e = intercept[SparkException] {
         checkAnswer(df2, Seq.empty)
@@ -422,22 +416,30 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
 
     val df4 = spark.table("h2.test.employee")
       .filter(($"salary" > 1000d).and($"salary" < 12000d))
-
     checkFiltersRemoved(df4, ansiMode)
-
-    df4.queryExecution.optimizedPlan.collect {
-      case _: DataSourceV2ScanRelation =>
-        val expected_plan_fragment = if (ansiMode) {
-          "PushedFilters: [SALARY IS NOT NULL, " +
-            "CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 12000.0], "
-        } else {
-          "PushedFilters: [SALARY IS NOT NULL], "
-        }
-        checkKeywordsExistsInExplain(df4, expected_plan_fragment)
+    val expectedPlanFragment4 = if (ansiMode) {
+      "PushedFilters: [SALARY IS NOT NULL, " +
+        "CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 12000.0], "
+    } else {
+      "PushedFilters: [SALARY IS NOT NULL], "
     }
-
+    checkPushedInfo(df4, expectedPlanFragment4)
     checkAnswer(df4, Seq(Row(1, "amy", 10000, 1000, true),
       Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true)))
+
+    val df5 = spark.table("h2.test.employee")
+      .filter(abs($"dept" - 3) > 1)
+      .filter(coalesce($"salary", $"bonus") > 2000)
+    checkFiltersRemoved(df5, ansiMode)
+    val expectedPlanFragment5 = if (ansiMode) {
+      "PushedFilters: [DEPT IS NOT NULL, ABS(DEPT - 3) > 1, " +
+        "(COALESCE(CAST(SALARY AS double), BONUS)) > 2000.0]"
+    } else {
+      "PushedFilters: [DEPT IS NOT NULL]"
+    }
+    checkPushedInfo(df5, expectedPlanFragment5)
+    checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true),
+      Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, true)))
     }
   }
 }
```

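Outside the suite's `checkPushedInfo` helper, the same assertion can be approximated by scanning the formatted plan text directly. A sketch, assuming a DataFrame built like `df5` above and Spark 3.x's `QueryExecution.explainString`:

```scala
import org.apache.spark.sql.execution.ExplainMode

// Pushed filters appear in the V2 scan's explain output, so a plain
// substring check on the formatted plan is enough for a quick verification.
val plan = df.queryExecution.explainString(ExplainMode.fromString("formatted"))
assert(plan.contains("ABS(DEPT - 3) > 1"))
```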