Skip to content

[SPARK-39162][SQL] Jdbc dialect should decide which function could be pushed down #36521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2404,10 +2404,6 @@ object QueryCompilationErrors extends QueryErrorsBase {
"Sinks cannot request distribution and ordering in continuous execution mode")
}

/**
 * Builds the error raised when a JDBC dialect cannot evaluate a pushed-down function.
 * @param database Name of the database/dialect that rejected the function.
 * @param funcInfo Compiled SQL text of the unsupported function call.
 * @return An [[AnalysisException]] describing the unsupported function.
 */
def noSuchFunctionError(database: String, funcInfo: String): Throwable = {
  val message = s"$database does not support function: $funcInfo"
  new AnalysisException(message)
}

// Return a more descriptive error message if the user tries to nest a DEFAULT column reference
// inside some other expression (such as DEFAULT + 1) in an INSERT INTO command's VALUES list;
// this is not allowed.
Expand Down
28 changes: 4 additions & 24 deletions sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,41 +20,21 @@ package org.apache.spark.sql.jdbc
import java.sql.{SQLException, Types}
import java.util.Locale

import scala.util.control.NonFatal

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.connector.expressions.Expression
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, ShortType, StringType}

private object H2Dialect extends JdbcDialect {
// Claims responsibility for any JDBC URL whose scheme is H2, case-insensitively.
override def canHandle(url: String): Boolean = {
  val normalized = url.toLowerCase(Locale.ROOT)
  normalized.startsWith("jdbc:h2")
}

// SQL builder specialized for H2: rejects functions H2 cannot evaluate so the
// surrounding compilation gives up the push-down instead of emitting bad SQL.
class H2SQLBuilder extends JDBCSQLBuilder {
  override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
    val compiled = super.visitSQLFunction(funcName, inputs)
    // H2 has no WIDTH_BUCKET function; fail fast for that one name.
    if (funcName == "WIDTH_BUCKET") {
      throw QueryCompilationErrors.noSuchFunctionError("H2", compiled)
    }
    compiled
  }
}
// Scalar functions H2 is known to evaluate; any other function name is
// rejected from push-down by `isSupportedFunction`.
private val supportedFunctions = Set(
  "ABS", "COALESCE", "LN", "EXP", "POWER", "SQRT", "FLOOR", "CEIL")

/**
 * Compiles a V2 expression to H2 SQL text.
 * @param expr The V2 expression to be compiled.
 * @return Some(sql) on success; None when the builder cannot translate the
 *         expression, in which case the push-down is abandoned.
 */
override def compileExpression(expr: Expression): Option[String] = {
  val builder = new H2SQLBuilder()
  try Some(builder.build(expr)) catch {
    case NonFatal(e) =>
      // Translation failure is expected for unsupported expressions; log and skip.
      logWarning("Error occurs while compiling V2 expression", e)
      None
  }
}
// A function may be pushed down to H2 only when it is in the allow-list above.
override def isSupportedFunction(funcName: String): Boolean =
  supportedFunctions(funcName)

override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
super.compileAggregate(aggFunction).orElse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,27 @@ abstract class JdbcDialect extends Serializable with Logging{
getJDBCType(dataType).map(_.databaseTypeDefinition).getOrElse(dataType.typeName)
s"CAST($l AS $databaseTypeDefinition)"
}

override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
  if (!isSupportedFunction(funcName)) {
    // The framework will catch the error and give up the push-down.
    // Please see `JdbcDialect.compileExpression(expr: Expression)` for more details.
    throw new UnsupportedOperationException(
      s"${this.getClass.getSimpleName} does not support function: $funcName")
  }
  // Render the call as NAME(arg1, arg2, ...).
  inputs.mkString(s"$funcName(", ", ", ")")
}
}

/**
 * Reports whether the underlying database can evaluate the given SQL function,
 * so the planner knows if the function call may be pushed down to the source.
 * @param funcName Upper-cased function name
 * @return True if the database supports function.
 */
@Since("3.3.0")
def isSupportedFunction(funcName: String): Boolean = false

/**
* Converts V2 expression to String representing a SQL expression.
* @param expr The V2 expression to be converted.
Expand Down