[SPARK-39270][SQL] JDBC dialect supports registering dialect specific functions #36649

Status: Closed · 8 commits
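In short: a JdbcDialect can now expose dialect-specific functions through a new functions hook, JDBCTableCatalog picks them up as a FunctionCatalog, and they become callable under the catalog name in SQL. Below is a minimal sketch of the flow as exercised by the JDBCV2Suite test in this PR; registerFunction and clearFunctions are test-only hooks on H2Dialect (which is private[sql], so this runs inside Spark's own sql packages), and the "h2" catalog plus the test.people table are assumed to be configured as in JDBCV2Suite.

// Sketch only; assumes spark.sql.catalog.h2 points at JDBCTableCatalog with an H2 url.
import org.apache.spark.sql.connector.IntegralAverage   // test helper defined later in this diff
import org.apache.spark.sql.jdbc.H2Dialect

H2Dialect.registerFunction("my_avg", IntegralAverage)

// The dialect's functions are read when the catalog initializes, so the function
// resolves directly under the catalog name rather than under a namespace:
spark.sql("SELECT h2.my_avg(id) FROM h2.test.people").show()

H2Dialect.clearFunctions()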
File: JDBCTableCatalog.scala
@@ -23,7 +23,9 @@ import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.catalog.{Identifier, NamespaceChange, SupportsNamespaces, Table, TableCatalog, TableChange}
import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException
import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier, NamespaceChange, SupportsNamespaces, Table, TableCatalog, TableChange}
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite, JDBCRDD, JdbcUtils}
@@ -32,10 +34,14 @@ import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging {
class JDBCTableCatalog extends TableCatalog
with SupportsNamespaces with FunctionCatalog with Logging {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

private var catalogName: String = null
private var options: JDBCOptions = _
private var dialect: JdbcDialect = _
private var functions: Map[String, UnboundFunction] = _

override def name(): String = {
require(catalogName != null, "The JDBC table catalog is not initialed")
@@ -52,6 +58,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging
// fake value, so that it can pass the check of `JDBCOptions`.
this.options = new JDBCOptions(map + (JDBCOptions.JDBC_TABLE_NAME -> "__invalid_dbtable"))
dialect = JdbcDialects.get(this.options.url)
functions = dialect.functions.toMap
}

override def listTables(namespace: Array[String]): Array[Identifier] = {
@@ -297,4 +304,24 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging
private def getTableName(ident: Identifier): String = {
(ident.namespace() :+ ident.name()).map(dialect.quoteIdentifier).mkString(".")
}

override def listFunctions(namespace: Array[String]): Array[Identifier] = {
if (namespace.isEmpty) {
functions.keys.map(Identifier.of(namespace, _)).toArray
} else {
Array.empty[Identifier]
}
}

override def loadFunction(ident: Identifier): UnboundFunction = {
if (ident.namespace().nonEmpty) {
throw QueryCompilationErrors.noSuchFunctionError(ident.asFunctionIdentifier)
}
functions.get(ident.name()) match {
case Some(func) =>
func
case _ =>
throw new NoSuchFunctionException(ident)
}
}
}
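The two methods above treat dialect functions as living in the catalog's root namespace: listFunctions only reports them for an empty namespace, and loadFunction rejects any namespaced identifier. A hypothetical illustration, assuming an in-memory H2 url and a dialect that has a function named "my_avg" registered:

// Hypothetical driver code illustrating the catalog behaviour added above.
import scala.collection.JavaConverters._
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.util.CaseInsensitiveStringMap

val catalog = new JDBCTableCatalog()
catalog.initialize("h2", new CaseInsensitiveStringMap(
  Map("url" -> "jdbc:h2:mem:testdb", "driver" -> "org.h2.Driver").asJava))

catalog.listFunctions(Array.empty[String])   // Array(Identifier.of(Array.empty, "my_avg"))
catalog.listFunctions(Array("test"))         // empty: functions only live in the root namespace
catalog.loadFunction(Identifier.of(Array.empty[String], "my_avg"))   // the registered UnboundFunction
// loadFunction with a non-empty namespace, or an unknown name, throws an AnalysisException.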
File: H2Dialect.scala
@@ -19,14 +19,18 @@ package org.apache.spark.sql.jdbc

import java.sql.{SQLException, Types}
import java.util.Locale
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DecimalType, ShortType, StringType}

private object H2Dialect extends JdbcDialect {
private[sql] object H2Dialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")

@@ -82,6 +86,21 @@ private object H2Dialect extends JdbcDialect {
case _ => JdbcUtils.getCommonJDBCType(dt)
}

private val functionMap: java.util.Map[String, UnboundFunction] =
new ConcurrentHashMap[String, UnboundFunction]()

// test only
def registerFunction(name: String, fn: UnboundFunction): UnboundFunction = {
functionMap.put(name, fn)
}

override def functions: Seq[(String, UnboundFunction)] = functionMap.asScala.toSeq

// test only
def clearFunctions(): Unit = {
functionMap.clear()
}

override def classifyException(message: String, e: Throwable): AnalysisException = {
e match {
case exception: SQLException =>
File: JdbcDialect.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.connector.catalog.TableChange._
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum}
@@ -323,6 +324,12 @@ abstract class JdbcDialect extends Serializable with Logging {
}
}

/**
* Lists the user-defined functions supported by this JDBC dialect.
* @return a sequence of (function name, unbound function) pairs.
*/
def functions: Seq[(String, UnboundFunction)] = Nil

/**
* Create schema with an optional comment. Empty string means no comment.
*/
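For dialects outside the built-in set, the hook is a plain override of the functions method introduced above. A hypothetical third-party dialect (the jdbc:mydb scheme and the function map are illustrative only, not part of this PR) might look like:

// Hypothetical dialect sketch: exposes dialect-specific functions via the new hook.
import java.util.Locale
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

class MyDialect(udfs: Map[String, UnboundFunction]) extends JdbcDialect {
  override def canHandle(url: String): Boolean =
    url.toLowerCase(Locale.ROOT).startsWith("jdbc:mydb")   // made-up scheme for illustration

  override def functions: Seq[(String, UnboundFunction)] = udfs.toSeq
}

// Registration follows the usual pattern; the UnboundFunction values come from the caller:
//   JdbcDialects.registerDialect(new MyDialect(Map("my_avg" -> myUnboundFunction)))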
File: DataSourceV2FunctionSuite.scala
@@ -35,6 +35,84 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object IntAverage extends AggregateFunction[(Int, Int), Int] {
override def name(): String = "iavg"
override def inputTypes(): Array[DataType] = Array(IntegerType)
override def resultType(): DataType = IntegerType

override def newAggregationState(): (Int, Int) = (0, 0)

override def update(state: (Int, Int), input: InternalRow): (Int, Int) = {
if (input.isNullAt(0)) {
state
} else {
val i = input.getInt(0)
state match {
case (_, 0) =>
(i, 1)
case (total, count) =>
(total + i, count + 1)
}
}
}

override def merge(leftState: (Int, Int), rightState: (Int, Int)): (Int, Int) = {
(leftState._1 + rightState._1, leftState._2 + rightState._2)
}

override def produceResult(state: (Int, Int)): Int = state._1 / state._2
}

object LongAverage extends AggregateFunction[(Long, Long), Long] {
override def name(): String = "iavg"
override def inputTypes(): Array[DataType] = Array(LongType)
override def resultType(): DataType = LongType

override def newAggregationState(): (Long, Long) = (0L, 0L)

override def update(state: (Long, Long), input: InternalRow): (Long, Long) = {
if (input.isNullAt(0)) {
state
} else {
val l = input.getLong(0)
state match {
case (_, 0L) =>
(l, 1)
case (total, count) =>
(total + l, count + 1L)
}
}
}

override def merge(leftState: (Long, Long), rightState: (Long, Long)): (Long, Long) = {
(leftState._1 + rightState._1, leftState._2 + rightState._2)
}

override def produceResult(state: (Long, Long)): Long = state._1 / state._2
}

object IntegralAverage extends UnboundFunction {
override def name(): String = "iavg"

override def bind(inputType: StructType): BoundFunction = {
if (inputType.fields.length > 1) {
throw new UnsupportedOperationException("Too many arguments")
}

inputType.fields(0).dataType match {
case _: IntegerType => IntAverage
case _: LongType => LongAverage
case dataType =>
throw new UnsupportedOperationException(s"Unsupported non-integral type: $dataType")
}
}

override def description(): String =
"""iavg: produces an average using integer division, ignoring nulls
| iavg(int) -> int
| iavg(bigint) -> bigint""".stripMargin
}
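As the description says, the aggregation state is a (sum, count) pair and the result uses integer division, so inputs 1, 2 and 4 fold to (7, 3) and produce 2. A quick hypothetical sanity check against the objects above (not part of the PR):

// Hypothetical check of IntAverage's state transitions.
import org.apache.spark.sql.catalyst.InternalRow

val state = Seq(1, 2, 4)
  .map(i => InternalRow(i))
  .foldLeft(IntAverage.newAggregationState())(IntAverage.update)
assert(state == (7, 3))
assert(IntAverage.produceResult(state) == 2)   // 7 / 3 with integer division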

class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
private val emptyProps: java.util.Map[String, String] = Collections.emptyMap[String, String]

@@ -537,84 +615,6 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
override def name(): String = "bad_bound_func"
}

object IntegralAverage extends UnboundFunction {
override def name(): String = "iavg"

override def bind(inputType: StructType): BoundFunction = {
if (inputType.fields.length > 1) {
throw new UnsupportedOperationException("Too many arguments")
}

inputType.fields(0).dataType match {
case _: IntegerType => IntAverage
case _: LongType => LongAverage
case dataType =>
throw new UnsupportedOperationException(s"Unsupported non-integral type: $dataType")
}
}

override def description(): String =
"""iavg: produces an average using integer division, ignoring nulls
| iavg(int) -> int
| iavg(bigint) -> bigint""".stripMargin
}

object IntAverage extends AggregateFunction[(Int, Int), Int] {
override def name(): String = "iavg"
override def inputTypes(): Array[DataType] = Array(IntegerType)
override def resultType(): DataType = IntegerType

override def newAggregationState(): (Int, Int) = (0, 0)

override def update(state: (Int, Int), input: InternalRow): (Int, Int) = {
if (input.isNullAt(0)) {
state
} else {
val i = input.getInt(0)
state match {
case (_, 0) =>
(i, 1)
case (total, count) =>
(total + i, count + 1)
}
}
}

override def merge(leftState: (Int, Int), rightState: (Int, Int)): (Int, Int) = {
(leftState._1 + rightState._1, leftState._2 + rightState._2)
}

override def produceResult(state: (Int, Int)): Int = state._1 / state._2
}

object LongAverage extends AggregateFunction[(Long, Long), Long] {
override def name(): String = "iavg"
override def inputTypes(): Array[DataType] = Array(LongType)
override def resultType(): DataType = LongType

override def newAggregationState(): (Long, Long) = (0L, 0L)

override def update(state: (Long, Long), input: InternalRow): (Long, Long) = {
if (input.isNullAt(0)) {
state
} else {
val l = input.getLong(0)
state match {
case (_, 0L) =>
(l, 1)
case (total, count) =>
(total + l, count + 1L)
}
}
}

override def merge(leftState: (Long, Long), rightState: (Long, Long)): (Long, Long) = {
(leftState._1 + rightState._1, leftState._2 + rightState._2)
}

override def produceResult(state: (Long, Long)): Long = state._1 / state._2
}

object UnboundDecimalAverage extends UnboundFunction {
override def name(): String = "decimal_avg"

File: JDBCV2Suite.scala
@@ -21,9 +21,10 @@ import java.sql.{Connection, DriverManager}
import java.util.Properties

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row}
import org.apache.spark.sql.{AnalysisException, DataFrame, ExplainSuiteHelper, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Sort}
import org.apache.spark.sql.connector.IntegralAverage
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.functions.{abs, avg, ceil, coalesce, count, count_distinct, exp, floor, lit, log => ln, not, pow, sqrt, sum, udf, when}
@@ -104,9 +105,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " +
"(1, 'bottle', 99999999999999999999.123)").executeUpdate()
}
H2Dialect.registerFunction("my_avg", IntegralAverage)
}

override def afterAll(): Unit = {
H2Dialect.clearFunctions()
Utils.deleteRecursively(tempDir)
super.afterAll()
}
@@ -1412,4 +1415,18 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
}
}
}

test("register dialect specific functions") {
val df = sql("SELECT h2.my_avg(id) FROM h2.test.people")
checkAggregateRemoved(df, false)
checkAnswer(df, Row(1) :: Nil)
val e1 = intercept[AnalysisException] {
checkAnswer(sql("SELECT h2.test.my_avg2(id) FROM h2.test.people"), Seq.empty)
}
assert(e1.getMessage.contains("Undefined function: h2.test.my_avg2"))
val e2 = intercept[AnalysisException] {
checkAnswer(sql("SELECT h2.my_avg2(id) FROM h2.test.people"), Seq.empty)
}
assert(e2.getMessage.contains("Undefined function: h2.my_avg2"))
}
}