Skip to content

[SPARK-39139][SQL] DS V2 supports push down DS V2 UDF #36593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 19 commits into from
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder;
import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;

/**
* The general representation of SQL scalar expressions, which contains the upper-cased
Expand Down Expand Up @@ -249,12 +249,7 @@ public int hashCode() {

@Override
public String toString() {
V2ExpressionSQLBuilder builder = new V2ExpressionSQLBuilder();
try {
return builder.build(this);
} catch (Throwable e) {
return name + "(" +
Arrays.stream(children).map(child -> child.toString()).reduce((a,b) -> a + "," + b) + ")";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's wrong with the previous code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous code made the `toString` output wrap the children in `Optional(...)`, because `Stream.reduce` returns an `Optional` that was concatenated into the string directly.

}
ToStringSQLBuilder builder = new ToStringSQLBuilder();
return builder.build(this);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.expressions;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Objects;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;

/**
* The general representation of user defined scalar function, which contains the upper-cased
* function name, canonical function name and all the children expressions.
*
* @since 3.4.0
*/
@Evolving
public class UserDefinedScalarFunc implements Expression, Serializable {
private String name;
private String canonicalName;
private Expression[] children;

public UserDefinedScalarFunc(String name, String canonicalName, Expression[] children) {
this.name = name;
this.canonicalName = canonicalName;
this.children = children;
}

public String name() { return name; }
public String canonicalName() { return canonicalName; }

@Override
public Expression[] children() { return children; }

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
UserDefinedScalarFunc that = (UserDefinedScalarFunc) o;
return Objects.equals(name, that.name) && Objects.equals(canonicalName, that.canonicalName) &&
Arrays.equals(children, that.children);
}

@Override
public int hashCode() {
return Objects.hash(name, canonicalName, children);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@beliefer , should it be Arrays.hashCode(children)?

}

@Override
public String toString() {
ToStringSQLBuilder builder = new ToStringSQLBuilder();
return builder.build(this);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@

package org.apache.spark.sql.connector.expressions.aggregate;

import java.util.Arrays;
import java.util.stream.Collectors;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;

/**
* The general implementation of {@link AggregateFunc}, which contains the upper-cased function
Expand All @@ -47,27 +45,21 @@ public final class GeneralAggregateFunc implements AggregateFunc {
private final boolean isDistinct;
private final Expression[] children;

public String name() { return name; }
public boolean isDistinct() { return isDistinct; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unnecessary change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just put these get-like method together.


public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) {
this.name = name;
this.isDistinct = isDistinct;
this.children = children;
}

public String name() { return name; }
public boolean isDistinct() { return isDistinct; }

@Override
public Expression[] children() { return children; }

@Override
public String toString() {
  // Delegate rendering to ToStringSQLBuilder (handles the DISTINCT prefix and
  // comma-joined inputs), replacing the previous hand-rolled formatting.
  ToStringSQLBuilder builder = new ToStringSQLBuilder();
  return builder.build(this);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.expressions.aggregate;

import java.util.Arrays;
import java.util.Objects;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;

/**
* The general representation of user defined aggregate function, which implements
* {@link AggregateFunc}, contains the upper-cased function name, the canonical function name,
* the `isDistinct` flag and all the inputs. Note that Spark cannot push down aggregate with
* this function partially to the source, but can only push down the entire aggregate.
*
* @since 3.4.0
*/
@Evolving
public class UserDefinedAggregateFunc implements AggregateFunc {
  private final String name;
  private final String canonicalName;
  private final boolean isDistinct;
  private final Expression[] children;

  public UserDefinedAggregateFunc(
      String name, String canonicalName, boolean isDistinct, Expression[] children) {
    this.name = name;
    this.canonicalName = canonicalName;
    this.isDistinct = isDistinct;
    this.children = children;
  }

  /** Returns the upper-cased function name. */
  public String name() { return name; }

  /** Returns the canonical (fully-qualified) function name. */
  public String canonicalName() { return canonicalName; }

  /** Returns true if this aggregate is applied with DISTINCT. */
  public boolean isDistinct() { return isDistinct; }

  @Override
  public Expression[] children() { return children; }

  // Value-based equality, mirroring UserDefinedScalarFunc; without it this class
  // fell back to identity equality, unlike its sibling expression classes.
  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    UserDefinedAggregateFunc that = (UserDefinedAggregateFunc) o;
    return isDistinct == that.isDistinct && Objects.equals(name, that.name) &&
      Objects.equals(canonicalName, that.canonicalName) &&
      Arrays.equals(children, that.children);
  }

  @Override
  public int hashCode() {
    // Arrays.hashCode keeps hashing consistent with the element-wise equals above.
    int result = Objects.hash(name, canonicalName, isDistinct);
    result = 31 * result + Arrays.hashCode(children);
    return result;
  }

  @Override
  public String toString() {
    ToStringSQLBuilder builder = new ToStringSQLBuilder();
    return builder.build(this);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
import org.apache.spark.sql.connector.expressions.Literal;
import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc;
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc;
import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc;
import org.apache.spark.sql.types.DataType;

/**
Expand Down Expand Up @@ -134,6 +137,18 @@ public String build(Expression expr) {
default:
return visitUnexpectedExpr(expr);
}
} else if (expr instanceof GeneralAggregateFunc) {
GeneralAggregateFunc f = (GeneralAggregateFunc) expr;
return visitGeneralAggregateFunction(f.name(), f.isDistinct(),
Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
} else if (expr instanceof UserDefinedScalarFunc) {
UserDefinedScalarFunc f = (UserDefinedScalarFunc) expr;
return visitUserDefinedScalarFunction(f.name(), f.canonicalName(),
Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
} else if (expr instanceof UserDefinedAggregateFunc) {
UserDefinedAggregateFunc f = (UserDefinedAggregateFunc) expr;
return visitUserDefinedAggregateFunction(f.name(), f.canonicalName(), f.isDistinct(),
Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
} else {
return visitUnexpectedExpr(expr);
}
Expand Down Expand Up @@ -246,6 +261,28 @@ protected String visitSQLFunction(String funcName, String[] inputs) {
return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
}

protected String visitGeneralAggregateFunction(
    String funcName, boolean isDistinct, String[] inputs) {
  // Render as NAME(arg, ...) or NAME(DISTINCT arg, ...).
  String joinedInputs = Arrays.stream(inputs).collect(Collectors.joining(", "));
  String open = isDistinct ? "(DISTINCT " : "(";
  return funcName + open + joinedInputs + ")";
}

protected String visitUserDefinedScalarFunction(
    String funcName, String canonicalName, String[] inputs) {
  // The base builder cannot translate UDFs to dialect SQL; subclasses that can
  // must override this hook.
  String message =
    this.getClass().getSimpleName() + " does not support user defined function: " + funcName;
  throw new UnsupportedOperationException(message);
}

protected String visitUserDefinedAggregateFunction(
    String funcName, String canonicalName, boolean isDistinct, String[] inputs) {
  // The base builder cannot translate aggregate UDFs to dialect SQL; subclasses
  // that can must override this hook.
  String message = this.getClass().getSimpleName() +
    " does not support user defined aggregate function: " + funcName;
  throw new UnsupportedOperationException(message);
}

protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
  // Terminal fallback for expression kinds this builder does not recognize.
  String message = "Unexpected V2 expression: " + expr;
  throw new IllegalArgumentException(message);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.internal.connector

import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder

/**
 * The builder to generate `toString` information of V2 expressions.
 *
 * Unlike the base `V2ExpressionSQLBuilder`, user-defined functions are rendered with their
 * plain name and comma-joined inputs instead of throwing `UnsupportedOperationException`,
 * so any V2 expression can be turned into a readable string.
 */
class ToStringSQLBuilder extends V2ExpressionSQLBuilder {
  // Explicit result types on both public overrides, matching each other.
  override protected def visitUserDefinedScalarFunction(
      funcName: String, canonicalName: String, inputs: Array[String]): String =
    s"""$funcName(${inputs.mkString(", ")})"""

  override protected def visitUserDefinedAggregateFunction(
      funcName: String,
      canonicalName: String,
      isDistinct: Boolean,
      inputs: Array[String]): String = {
    // DISTINCT goes inside the parentheses, e.g. MY_AGG(DISTINCT a, b).
    val distinct = if (isDistinct) "DISTINCT " else ""
    s"""$funcName($distinct${inputs.mkString(", ")})"""
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.types.BooleanType

Expand Down Expand Up @@ -283,6 +283,14 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
None
}
// TODO supports other expressions
case ApplyFunctionExpression(function, children) =>
val childrenExpressions = children.flatMap(generateExpression(_))
if (childrenExpressions.length == children.length) {
Some(new UserDefinedScalarFunc(
function.name(), function.canonicalName(), childrenExpressions.toArray[V2Expression]))
} else {
None
}
case _ => None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, V2ExpressionBu
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
Expand Down Expand Up @@ -750,6 +750,14 @@ object DataSourceStrategy
PushableColumnWithoutNestedColumn(right), _) =>
Some(new GeneralAggregateFunc("CORR", agg.isDistinct,
Array(FieldReference.column(left), FieldReference.column(right))))
case aggregate.V2Aggregator(aggrFunc, children, _, _) =>
val translatedExprs = children.flatMap(PushableExpression.unapply(_))
if (translatedExprs.length == children.length) {
Some(new UserDefinedAggregateFunc(aggrFunc.name(),
aggrFunc.canonicalName(), agg.isDistinct, translatedExprs.toArray[V2Expression]))
} else {
None
}
case _ => None
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum}
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum, UserDefinedAggregateFunc}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
Expand Down Expand Up @@ -299,6 +299,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
case count: Count => !count.isDistinct
case avg: Avg => !avg.isDistinct
case _: GeneralAggregateFunc => false
case _: UserDefinedAggregateFunc => false
case _ => true
}
}
Expand Down
Loading