Skip to content

Commit c92e408

Browse files
imback82 and cloud-fan
authored and committed
[SPARK-34388][SQL] Propagate the registered UDF name to ScalaUDF, ScalaUDAF and ScalaAggregator
### What changes were proposed in this pull request? This PR proposes to propagate the name used for registering UDFs to `ScalaUDF`, `ScalaUDAF` and `ScalaAggregator`. Note that `PythonUDF` gets the name correctly: https://github.com/apache/spark/blob/466c045bfac20b6ce19f5a3732e76a5de4eb4e4a/python/pyspark/sql/udf.py#L358-L359 , and same for Hive UDFs: https://github.com/apache/spark/blob/466c045bfac20b6ce19f5a3732e76a5de4eb4e4a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala#L67 ### Why are the changes needed? This PR can help in the following scenarios: 1) Better EXPLAIN output 2) By adding `def name: String` to `UserDefinedExpression`, we can match an expression by `UserDefinedExpression` and look up the catalog, a use case needed for apache#31273. ### Does this PR introduce _any_ user-facing change? The EXPLAIN output involving udfs will be changed to use the name used for UDF registration. For example, for the following: ``` sql("CREATE TEMPORARY FUNCTION test_udf AS 'org.apache.spark.examples.sql.Spark33084'") sql("SELECT test_udf(col1) FROM VALUES (1), (2), (3)").explain(true) ``` The output of the optimized plan will change from: ``` Aggregate [spark33084(cast(col1#223 as bigint), org.apache.spark.examples.sql.Spark330846906be0f, 1, 1) AS spark33084(col1)apache#237] +- LocalRelation [col1#223] ``` to ``` Aggregate [test_udf(cast(col1#223 as bigint), org.apache.spark.examples.sql.Spark330847a62d697, 1, 1, Some(test_udf)) AS test_udf(col1)apache#237] +- LocalRelation [col1#223] ``` ### How was this patch tested? Added new tests. Closes apache#31500 from imback82/udaf_name. Authored-by: Terry Kim <yuminkim@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent d1131bc commit c92e408

File tree

8 files changed

+108
-38
lines changed

8 files changed

+108
-38
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,9 +1409,14 @@ class SessionCatalog(
14091409
Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction")
14101410
if (clsForUDAF.isAssignableFrom(clazz)) {
14111411
val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF")
1412-
val e = cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int])
1413-
.newInstance(input,
1414-
clazz.getConstructor().newInstance().asInstanceOf[Object], Int.box(1), Int.box(1))
1412+
val e = cls.getConstructor(
1413+
classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int], classOf[Option[String]])
1414+
.newInstance(
1415+
input,
1416+
clazz.getConstructor().newInstance().asInstanceOf[Object],
1417+
Int.box(1),
1418+
Int.box(1),
1419+
Some(name))
14151420
.asInstanceOf[ImplicitCastInputTypes]
14161421

14171422
// Check input argument size

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1088,4 +1088,6 @@ trait ComplexTypeMergingExpression extends Expression {
10881088
* Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages
10891089
* and Hive function wrappers.
10901090
*/
1091-
trait UserDefinedExpression
1091+
trait UserDefinedExpression {
1092+
def name: String
1093+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ case class ScalaUDF(
5757

5858
override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
5959

60-
override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})"
60+
override def toString: String = s"$name(${children.mkString(", ")})"
61+
62+
override def name: String = udfName.getOrElse("UDF")
6163

6264
override lazy val canonicalized: Expression = {
6365
// SPARK-32307: `ExpressionEncoder` can't be canonicalized, and technically we don't

sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
8282
@deprecated("Aggregator[IN, BUF, OUT] should now be registered as a UDF" +
8383
" via the functions.udaf(agg) method.", "3.0.0")
8484
def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
85-
def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf)
85+
def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf, udafName = Some(name))
8686
functionRegistry.createOrReplaceTempFunction(name, builder)
8787
udaf
8888
}
@@ -109,15 +109,15 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
109109
* @since 2.2.0
110110
*/
111111
def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = {
112-
udf match {
112+
udf.withName(name) match {
113113
case udaf: UserDefinedAggregator[_, _, _] =>
114114
def builder(children: Seq[Expression]) = udaf.scalaAggregator(children)
115115
functionRegistry.createOrReplaceTempFunction(name, builder)
116-
udf
117-
case _ =>
118-
def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
116+
udaf
117+
case other =>
118+
def builder(children: Seq[Expression]) = other.apply(children.map(Column.apply) : _*).expr
119119
functionRegistry.createOrReplaceTempFunction(name, builder)
120-
udf
120+
other
121121
}
122122
}
123123

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,8 @@ case class ScalaUDAF(
325325
children: Seq[Expression],
326326
udaf: UserDefinedAggregateFunction,
327327
mutableAggBufferOffset: Int = 0,
328-
inputAggBufferOffset: Int = 0)
328+
inputAggBufferOffset: Int = 0,
329+
udafName: Option[String] = None)
329330
extends ImperativeAggregate
330331
with NonSQLExpression
331332
with Logging
@@ -447,10 +448,12 @@ case class ScalaUDAF(
447448
}
448449

449450
override def toString: String = {
450-
s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
451+
s"""$nodeName(${children.mkString(",")})"""
451452
}
452453

453-
override def nodeName: String = udaf.getClass.getSimpleName
454+
override def nodeName: String = name
455+
456+
override def name: String = udafName.getOrElse(udaf.getClass.getSimpleName)
454457
}
455458

456459
case class ScalaAggregator[IN, BUF, OUT](
@@ -461,7 +464,8 @@ case class ScalaAggregator[IN, BUF, OUT](
461464
nullable: Boolean = true,
462465
isDeterministic: Boolean = true,
463466
mutableAggBufferOffset: Int = 0,
464-
inputAggBufferOffset: Int = 0)
467+
inputAggBufferOffset: Int = 0,
468+
aggregatorName: Option[String] = None)
465469
extends TypedImperativeAggregate[BUF]
466470
with NonSQLExpression
467471
with UserDefinedExpression
@@ -513,7 +517,9 @@ case class ScalaAggregator[IN, BUF, OUT](
513517

514518
override def toString: String = s"""${nodeName}(${children.mkString(",")})"""
515519

516-
override def nodeName: String = agg.getClass.getSimpleName
520+
override def nodeName: String = name
521+
522+
override def name: String = aggregatorName.getOrElse(agg.getClass.getSimpleName)
517523
}
518524

519525
/**

sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
150150
def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = {
151151
val iEncoder = inputEncoder.asInstanceOf[ExpressionEncoder[IN]]
152152
val bEncoder = aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]]
153-
ScalaAggregator(exprs, aggregator, iEncoder, bEncoder, nullable, deterministic)
153+
ScalaAggregator(
154+
exprs, aggregator, iEncoder, bEncoder, nullable, deterministic, aggregatorName = name)
154155
}
155156

156157
override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = {

sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -321,26 +321,34 @@ object IntegratedUDFTestUtils extends SQLHelper {
321321
* casted_col.cast(df.schema("col").dataType)
322322
* }}}
323323
*/
324-
case class TestScalaUDF(name: String) extends TestUDF {
325-
private[IntegratedUDFTestUtils] lazy val udf = new SparkUserDefinedFunction(
326-
(input: Any) => if (input == null) {
327-
null
328-
} else {
329-
input.toString
330-
},
331-
StringType,
332-
inputEncoders = Seq.fill(1)(None),
333-
name = Some(name)) {
334-
335-
override def apply(exprs: Column*): Column = {
336-
assert(exprs.length == 1, "Defined UDF only has one column")
337-
val expr = exprs.head.expr
338-
assert(expr.resolved, "column should be resolved to use the same type " +
339-
"as input. Try df(name) or df.col(name)")
340-
Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), expr.dataType))
341-
}
324+
class TestInternalScalaUDF(name: String) extends SparkUserDefinedFunction(
325+
(input: Any) => if (input == null) {
326+
null
327+
} else {
328+
input.toString
329+
},
330+
StringType,
331+
inputEncoders = Seq.fill(1)(None),
332+
name = Some(name)) {
333+
334+
override def apply(exprs: Column*): Column = {
335+
assert(exprs.length == 1, "Defined UDF only has one column")
336+
val expr = exprs.head.expr
337+
assert(expr.resolved, "column should be resolved to use the same type " +
338+
"as input. Try df(name) or df.col(name)")
339+
Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), expr.dataType))
342340
}
343341

342+
override def withName(name: String): TestInternalScalaUDF = {
343+
// "withName" should overridden to return TestInternalScalaUDF. Otherwise, the current object
344+
// is sliced and the overridden "apply" is not invoked.
345+
new TestInternalScalaUDF(name)
346+
}
347+
}
348+
349+
case class TestScalaUDF(name: String) extends TestUDF {
350+
private[IntegratedUDFTestUtils] lazy val udf = new TestInternalScalaUDF(name)
351+
344352
def apply(exprs: Column*): Column = udf(exprs: _*)
345353

346354
val prettyName: String = "Scala UDF"

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,18 @@ import scala.collection.mutable.{ArrayBuffer, WrappedArray}
2626

2727
import org.apache.spark.SparkException
2828
import org.apache.spark.sql.api.java._
29-
import org.apache.spark.sql.catalyst.encoders.OuterScopes
29+
import org.apache.spark.sql.catalyst.FunctionIdentifier
30+
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, OuterScopes}
31+
import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUDF}
3032
import org.apache.spark.sql.catalyst.plans.logical.Project
3133
import org.apache.spark.sql.catalyst.util.DateTimeUtils
3234
import org.apache.spark.sql.execution.{QueryExecution, SimpleMode}
35+
import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF}
3336
import org.apache.spark.sql.execution.columnar.InMemoryRelation
3437
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
3538
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
36-
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
37-
import org.apache.spark.sql.functions.{lit, struct, udf}
39+
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, SparkUserDefinedFunction, UserDefinedAggregateFunction}
40+
import org.apache.spark.sql.functions.{lit, struct, udaf, udf}
3841
import org.apache.spark.sql.internal.SQLConf
3942
import org.apache.spark.sql.test.SharedSparkSession
4043
import org.apache.spark.sql.test.SQLTestData._
@@ -798,4 +801,47 @@ class UDFSuite extends QueryTest with SharedSparkSession {
798801
.select(myUdf(Column("col"))),
799802
Row(ArrayBuffer(100)))
800803
}
804+
805+
test("SPARK-34388: UDF name is propagated with registration for ScalaUDF") {
806+
spark.udf.register("udf34388", udf((value: Int) => value > 2))
807+
spark.sessionState.catalog.lookupFunction(
808+
FunctionIdentifier("udf34388"), Seq(Literal(1))) match {
809+
case udf: ScalaUDF => assert(udf.name === "udf34388")
810+
}
811+
}
812+
813+
test("SPARK-34388: UDF name is propagated with registration for ScalaAggregator") {
814+
val agg = new Aggregator[Long, Long, Long] {
815+
override def zero: Long = 0L
816+
override def reduce(b: Long, a: Long): Long = a + b
817+
override def merge(b1: Long, b2: Long): Long = b1 + b2
818+
override def finish(reduction: Long): Long = reduction
819+
override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
820+
override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
821+
}
822+
823+
spark.udf.register("agg34388", udaf(agg))
824+
spark.sessionState.catalog.lookupFunction(
825+
FunctionIdentifier("agg34388"), Seq(Literal(1))) match {
826+
case agg: ScalaAggregator[_, _, _] => assert(agg.name === "agg34388")
827+
}
828+
}
829+
830+
test("SPARK-34388: UDF name is propagated with registration for ScalaUDAF") {
831+
val udaf = new UserDefinedAggregateFunction {
832+
def inputSchema: StructType = new StructType().add("a", LongType)
833+
def bufferSchema: StructType = new StructType().add("product", LongType)
834+
def dataType: DataType = LongType
835+
def deterministic: Boolean = true
836+
def initialize(buffer: MutableAggregationBuffer): Unit = {}
837+
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {}
838+
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {}
839+
def evaluate(buffer: Row): Any = buffer.getLong(0)
840+
}
841+
spark.udf.register("udaf34388", udaf)
842+
spark.sessionState.catalog.lookupFunction(
843+
FunctionIdentifier("udaf34388"), Seq(Literal(1))) match {
844+
case udaf: ScalaUDAF => assert(udaf.name === "udaf34388")
845+
}
846+
}
801847
}

0 commit comments

Comments (0)