
Commit b242f85

beliefer authored and cloud-fan committed
[SPARK-39139][SQL] DS V2 supports push down DS V2 UDF
### What changes were proposed in this pull request?
Currently, the Spark DS V2 push-down framework supports pushing SQL down to data sources, but only for Spark's built-in functions. Each database has many useful functions that Spark does not support. If we can push these functions down to the data source, we reduce disk I/O and network I/O and improve performance when querying databases.

### Why are the changes needed?
1. Spark can leverage the functions supported by databases.
2. Improve the query performance.

### Does this PR introduce _any_ user-facing change?
No. New feature.

### How was this patch tested?
New tests.

Closes #36593 from beliefer/SPARK-39139.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent c1106fb commit b242f85
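To make the new API concrete, here is a minimal, hypothetical sketch (not part of this commit) of the representation it introduces: a call to a catalog function travels through the push-down framework as a `UserDefinedScalarFunc` carrying the upper-cased name, the canonical name, and the child expressions. The function and column names below are made up.

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, UserDefinedScalarFunc}

// Hypothetical UDF call to be pushed to the source: CHAR_LENGTH(name).
// "h2.char_length" stands in for whatever canonical name the catalog reports.
val udf = new UserDefinedScalarFunc(
  "CHAR_LENGTH",                                     // upper-cased function name
  "h2.char_length",                                  // canonical function name
  Array[Expression](FieldReference.column("name")))  // children

// toString delegates to ToStringSQLBuilder, so this should print
// something like "CHAR_LENGTH(name)".
println(udf)
```

Spark builds these objects itself in `V2ExpressionBuilder` and `DataSourceStrategy` (see the diffs below); a connector only needs to translate them into its own SQL dialect.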

File tree: 11 files changed (+422, −53 lines)


sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java

Lines changed: 3 additions & 8 deletions
@@ -23,7 +23,7 @@
 
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.filter.Predicate;
-import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder;
+import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;
 
 /**
  * The general representation of SQL scalar expressions, which contains the upper-cased
@@ -381,12 +381,7 @@ public int hashCode() {
 
   @Override
   public String toString() {
-    V2ExpressionSQLBuilder builder = new V2ExpressionSQLBuilder();
-    try {
-      return builder.build(this);
-    } catch (Throwable e) {
-      return name + "(" +
-        Arrays.stream(children).map(child -> child.toString()).reduce((a,b) -> a + "," + b) + ")";
-    }
+    ToStringSQLBuilder builder = new ToStringSQLBuilder();
+    return builder.build(this);
   }
 }
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java

Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Objects;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;
+
+/**
+ * The general representation of user defined scalar function, which contains the upper-cased
+ * function name, canonical function name and all the children expressions.
+ *
+ * @since 3.4.0
+ */
+@Evolving
+public class UserDefinedScalarFunc implements Expression, Serializable {
+  private String name;
+  private String canonicalName;
+  private Expression[] children;
+
+  public UserDefinedScalarFunc(String name, String canonicalName, Expression[] children) {
+    this.name = name;
+    this.canonicalName = canonicalName;
+    this.children = children;
+  }
+
+  public String name() { return name; }
+  public String canonicalName() { return canonicalName; }
+
+  @Override
+  public Expression[] children() { return children; }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+    UserDefinedScalarFunc that = (UserDefinedScalarFunc) o;
+    return Objects.equals(name, that.name) && Objects.equals(canonicalName, that.canonicalName) &&
+      Arrays.equals(children, that.children);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(name, canonicalName, children);
+  }
+
+  @Override
+  public String toString() {
+    ToStringSQLBuilder builder = new ToStringSQLBuilder();
+    return builder.build(this);
+  }
+}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java

Lines changed: 6 additions & 14 deletions
@@ -17,11 +17,9 @@
 
 package org.apache.spark.sql.connector.expressions.aggregate;
 
-import java.util.Arrays;
-import java.util.stream.Collectors;
-
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;
 
 /**
  * The general implementation of {@link AggregateFunc}, which contains the upper-cased function
@@ -47,27 +45,21 @@ public final class GeneralAggregateFunc implements AggregateFunc {
   private final boolean isDistinct;
   private final Expression[] children;
 
-  public String name() { return name; }
-  public boolean isDistinct() { return isDistinct; }
-
   public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) {
     this.name = name;
    this.isDistinct = isDistinct;
     this.children = children;
   }
 
+  public String name() { return name; }
+  public boolean isDistinct() { return isDistinct; }
+
   @Override
   public Expression[] children() { return children; }
 
   @Override
   public String toString() {
-    String inputsString = Arrays.stream(children)
-      .map(Expression::describe)
-      .collect(Collectors.joining(", "));
-    if (isDistinct) {
-      return name + "(DISTINCT " + inputsString + ")";
-    } else {
-      return name + "(" + inputsString + ")";
-    }
+    ToStringSQLBuilder builder = new ToStringSQLBuilder();
+    return builder.build(this);
   }
 }
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java

Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;
+
+/**
+ * The general representation of user defined aggregate function, which implements
+ * {@link AggregateFunc}, contains the upper-cased function name, the canonical function name,
+ * the `isDistinct` flag and all the inputs. Note that Spark cannot push down aggregate with
+ * this function partially to the source, but can only push down the entire aggregate.
+ *
+ * @since 3.4.0
+ */
+@Evolving
+public class UserDefinedAggregateFunc implements AggregateFunc {
+  private final String name;
+  private String canonicalName;
+  private final boolean isDistinct;
+  private final Expression[] children;
+
+  public UserDefinedAggregateFunc(
+      String name, String canonicalName, boolean isDistinct, Expression[] children) {
+    this.name = name;
+    this.canonicalName = canonicalName;
+    this.isDistinct = isDistinct;
+    this.children = children;
+  }
+
+  public String name() { return name; }
+  public String canonicalName() { return canonicalName; }
+  public boolean isDistinct() { return isDistinct; }
+
+  @Override
+  public Expression[] children() { return children; }
+
+  @Override
+  public String toString() {
+    ToStringSQLBuilder builder = new ToStringSQLBuilder();
+    return builder.build(this);
+  }
+}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java

Lines changed: 37 additions & 0 deletions
@@ -26,6 +26,9 @@
 import org.apache.spark.sql.connector.expressions.NamedReference;
 import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
 import org.apache.spark.sql.connector.expressions.Literal;
+import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc;
+import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc;
+import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc;
 import org.apache.spark.sql.types.DataType;
 
 /**
@@ -156,6 +159,18 @@ public String build(Expression expr) {
         default:
           return visitUnexpectedExpr(expr);
       }
+    } else if (expr instanceof GeneralAggregateFunc) {
+      GeneralAggregateFunc f = (GeneralAggregateFunc) expr;
+      return visitGeneralAggregateFunction(f.name(), f.isDistinct(),
+        Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
+    } else if (expr instanceof UserDefinedScalarFunc) {
+      UserDefinedScalarFunc f = (UserDefinedScalarFunc) expr;
+      return visitUserDefinedScalarFunction(f.name(), f.canonicalName(),
+        Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
+    } else if (expr instanceof UserDefinedAggregateFunc) {
+      UserDefinedAggregateFunc f = (UserDefinedAggregateFunc) expr;
+      return visitUserDefinedAggregateFunction(f.name(), f.canonicalName(), f.isDistinct(),
+        Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
     } else {
       return visitUnexpectedExpr(expr);
     }
@@ -268,6 +283,28 @@ protected String visitSQLFunction(String funcName, String[] inputs) {
     return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
   }
 
+  protected String visitGeneralAggregateFunction(
+      String funcName, boolean isDistinct, String[] inputs) {
+    if (isDistinct) {
+      return funcName +
+        "(DISTINCT " + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
+    } else {
+      return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
+    }
+  }
+
+  protected String visitUserDefinedScalarFunction(
+      String funcName, String canonicalName, String[] inputs) {
+    throw new UnsupportedOperationException(
+      this.getClass().getSimpleName() + " does not support user defined function: " + funcName);
+  }
+
+  protected String visitUserDefinedAggregateFunction(
+      String funcName, String canonicalName, boolean isDistinct, String[] inputs) {
+    throw new UnsupportedOperationException(this.getClass().getSimpleName() +
+      " does not support user defined aggregate function: " + funcName);
+  }
+
   protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
     throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
   }
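The two new visit methods throw by default, so a connector opts in by overriding them in its own `V2ExpressionSQLBuilder` subclass. The sketch below is illustrative only (it is not part of this commit, and the class name and canonical names in the match are made up); it shows how a dialect-specific builder could map the canonical UDF names it recognizes to the remote database's syntax and fall back to the default (unsupported) behavior for everything else, so the caller can decide not to push that expression down.

```scala
import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder

// Hypothetical dialect-side builder; "h2.char_length" and "h2.iavg" are made-up canonical names.
class MyDialectSQLBuilder extends V2ExpressionSQLBuilder {

  override protected def visitUserDefinedScalarFunction(
      funcName: String, canonicalName: String, inputs: Array[String]): String =
    canonicalName match {
      // Only functions the remote engine understands get compiled to SQL.
      case "h2.char_length" => s"CHAR_LENGTH(${inputs.mkString(", ")})"
      // Anything else keeps the default behavior (throws), signalling "not supported".
      case _ => super.visitUserDefinedScalarFunction(funcName, canonicalName, inputs)
    }

  override protected def visitUserDefinedAggregateFunction(
      funcName: String,
      canonicalName: String,
      isDistinct: Boolean,
      inputs: Array[String]): String = {
    val distinct = if (isDistinct) "DISTINCT " else ""
    canonicalName match {
      case "h2.iavg" => s"AVG($distinct${inputs.mkString(", ")})"
      case _ =>
        super.visitUserDefinedAggregateFunction(funcName, canonicalName, isDistinct, inputs)
    }
  }
}
```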
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal.connector
+
+import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
+
+/**
+ * The builder to generate `toString` information of V2 expressions.
+ */
+class ToStringSQLBuilder extends V2ExpressionSQLBuilder {
+  override protected def visitUserDefinedScalarFunction(
+      funcName: String, canonicalName: String, inputs: Array[String]) =
+    s"""$funcName(${inputs.mkString(", ")})"""
+
+  override protected def visitUserDefinedAggregateFunction(
+      funcName: String,
+      canonicalName: String,
+      isDistinct: Boolean,
+      inputs: Array[String]): String = {
+    val distinct = if (isDistinct) "DISTINCT " else ""
+    s"""$funcName($distinct${inputs.mkString(", ")})"""
+  }
+}
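As a quick, hypothetical illustration of the rendering this builder produces (assuming the new aggregate class from this commit; the names are made up), a distinct user-defined aggregate is printed with a `DISTINCT ` prefix:

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference}
import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc

// "h2.iavg" is an illustrative canonical name; any catalog-reported name works the same way.
val agg = new UserDefinedAggregateFunc(
  "IAVG", "h2.iavg", true, Array[Expression](FieldReference.column("salary")))

// toString delegates to ToStringSQLBuilder, so this should print roughly "IAVG(DISTINCT salary)".
println(agg)
```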

sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 9 additions & 1 deletion
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.util
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
+import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.types.BooleanType
 
@@ -345,6 +345,14 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
        None
      }
    // TODO supports other expressions
+    case ApplyFunctionExpression(function, children) =>
+      val childrenExpressions = children.flatMap(generateExpression(_))
+      if (childrenExpressions.length == children.length) {
+        Some(new UserDefinedScalarFunc(
+          function.name(), function.canonicalName(), childrenExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
    case _ => None
  }
}

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 9 additions & 1 deletion
@@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, V2ExpressionBu
 import org.apache.spark.sql.connector.catalog.SupportsRead
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.command._
@@ -751,6 +751,14 @@ object DataSourceStrategy
         PushableColumnWithoutNestedColumn(right), _) =>
         Some(new GeneralAggregateFunc("CORR", agg.isDistinct,
           Array(FieldReference.column(left), FieldReference.column(right))))
+      case aggregate.V2Aggregator(aggrFunc, children, _, _) =>
+        val translatedExprs = children.flatMap(PushableExpression.unapply(_))
+        if (translatedExprs.length == children.length) {
+          Some(new UserDefinedAggregateFunc(aggrFunc.name(),
+            aggrFunc.canonicalName(), agg.isDistinct, translatedExprs.toArray[V2Expression]))
+        } else {
+          None
+        }
       case _ => None
     }
   } else {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala

Lines changed: 2 additions & 1 deletion
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
-import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum, UserDefinedAggregateFunc}
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -299,6 +299,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
       case count: Count => !count.isDistinct
       case avg: Avg => !avg.isDistinct
       case _: GeneralAggregateFunc => false
+      case _: UserDefinedAggregateFunc => false
       case _ => true
     }
   }
