Commit f00df40

zjffdu authored and marmbrus committed
[SPARK-11775][PYSPARK][SQL] Allow PySpark to register Java UDF
Currently PySpark can only call built-in Java UDFs, not custom ones. It would be better to allow that. Two benefits:

* Leverage the power of rich third-party Java libraries.
* Improve performance: a Python UDF requires Python daemons to be started on each worker, which hurts performance.

Author: Jeff Zhang <zjffdu@apache.org>

Closes #9766 from zjffdu/SPARK-11775.
1 parent 5aeb738 commit f00df40
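
To make the change concrete, here is a minimal sketch of the new usage from PySpark. It assumes the JavaStringLength test class added in this commit has been compiled and placed on the JVM classpath (for example via `spark-submit --jars`), and that `sc` is an existing SparkContext.

    from pyspark.sql import SQLContext
    from pyspark.sql.types import IntegerType

    # Register the Java class as a SQL function; the UDF runs entirely in the
    # JVM, so no Python worker daemons are involved.
    sqlContext = SQLContext(sc)
    sqlContext.registerJavaFunction(
        "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
    sqlContext.sql("SELECT javaStringLength('test')").collect()
    # [Row(UDF(test)=4)]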

File tree: 5 files changed (+152, −4 lines)


python/pyspark/sql/context.py
Lines changed: 27 additions & 1 deletion

@@ -28,7 +28,7 @@
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import Row, StringType
+from pyspark.sql.types import IntegerType, Row, StringType
 from pyspark.sql.utils import install_exception_handler

 __all__ = ["SQLContext", "HiveContext", "UDFRegistration"]

@@ -202,6 +202,32 @@ def registerFunction(self, name, f, returnType=StringType()):
         """
         self.sparkSession.catalog.registerFunction(name, f, returnType)

+    @ignore_unicode_prefix
+    @since(2.1)
+    def registerJavaFunction(self, name, javaClassName, returnType=None):
+        """Register a Java UDF so it can be used in SQL statements.
+
+        In addition to a name and the function itself, the return type can be
+        optionally specified. When the return type is not specified, it is
+        inferred via reflection.
+
+        :param name: name of the UDF
+        :param javaClassName: fully qualified name of the Java class
+        :param returnType: a :class:`pyspark.sql.types.DataType` object
+
+        >>> sqlContext.registerJavaFunction("javaStringLength",
+        ...     "test.org.apache.spark.sql.JavaStringLength", IntegerType())
+        >>> sqlContext.sql("SELECT javaStringLength('test')").collect()
+        [Row(UDF(test)=4)]
+        >>> sqlContext.registerJavaFunction("javaStringLength2",
+        ...     "test.org.apache.spark.sql.JavaStringLength")
+        >>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
+        [Row(UDF(test)=4)]
+        """
+        jdt = None
+        if returnType is not None:
+            jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
+        self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
+
     # TODO(andrew): delete this once we refactor things to take in SparkSession
     def _inferSchema(self, rdd, samplingRatio=None):
         """

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ object JavaTypeInference {
    * @param typeToken Java type
    * @return (SQL data type, nullable)
    */
-  private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
+  private[sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
     typeToken.getRawType match {
       case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
         (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
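
Widening `inferDataType` from `private` to `private[sql]` is what lets `UDFRegistration.registerJava` (next file) infer the SQL return type from the UDF's generic type parameters. From PySpark, this is the path taken when `returnType` is omitted. A sketch, reusing the `sqlContext` from the earlier example:

    # No returnType given: the JVM infers IntegerType from UDF1<String, Integer>.
    sqlContext.registerJavaFunction(
        "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
    sqlContext.sql("SELECT javaStringLength2('test')").collect()
    # [Row(UDF(test)=4)]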

sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
Lines changed: 73 additions & 2 deletions

@@ -17,19 +17,25 @@

 package org.apache.spark.sql

+import java.io.IOException
+import java.lang.reflect.{ParameterizedType, Type}
+
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.Try

+import com.google.common.reflect.TypeToken
+
 import org.apache.spark.annotation.InterfaceStability
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.api.java._
+import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
-import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
 import org.apache.spark.sql.execution.aggregate.ScalaUDAF
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DataTypes}
+import org.apache.spark.util.Utils

 /**
  * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this.

@@ -413,6 +419,71 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   //////////////////////////////////////////////////////////////////////////////////////////////
   //////////////////////////////////////////////////////////////////////////////////////////////

+  /**
+   * Register a Java UDF class using reflection, for use from pyspark
+   *
+   * @param name udf name
+   * @param className fully qualified class name of udf
+   * @param returnDataType return type of udf. If it is null, spark would try to infer
+   *                       via reflection.
+   */
+  private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = {
+    try {
+      val clazz = Utils.classForName(className)
+      val udfInterfaces = clazz.getGenericInterfaces
+        .filter(_.isInstanceOf[ParameterizedType])
+        .map(_.asInstanceOf[ParameterizedType])
+        .filter(e => e.getRawType.isInstanceOf[Class[_]] &&
+          e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF"))
+      if (udfInterfaces.length == 0) {
+        throw new IOException(s"UDF class ${className} doesn't implement any UDF interface")
+      } else if (udfInterfaces.length > 1) {
+        throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}")
+      } else {
+        try {
+          val udf = clazz.newInstance()
+          val udfReturnType = udfInterfaces(0).getActualTypeArguments.last
+          var returnType = returnDataType
+          if (returnType == null) {
+            returnType = JavaTypeInference.inferDataType(TypeToken.of(udfReturnType))._1
+          }
+
+          udfInterfaces(0).getActualTypeArguments.length match {
+            case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType)
+            case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType)
+            case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType)
+            case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType)
+            case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType)
+            case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType)
+            case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType)
+            case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType)
+            case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType)
+            case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 13 => register(name, udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 15 => register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case n => logError(s"UDF class with ${n} type arguments is not supported")
+          }
+        } catch {
+          case e @ (_: InstantiationException | _: IllegalArgumentException) =>
+            logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
+        }
+      }
+    } catch {
+      case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath")
+    }
+  }
+
   /**
    * Register a user-defined function with 1 arguments.
    * @since 1.3.0
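
Two details of `registerJava` are worth calling out. First, the `case` arms are offset by one because `getActualTypeArguments` includes the return type: a `UDF1` carries two type arguments, hence `case 2`. Second, class-loading and instantiation failures are logged with `logError` rather than thrown, so from PySpark a bad class name only surfaces later, when SQL cannot resolve the function. A hedged sketch; the class name below is made up:

    # Hypothetical class name: registration logs an error instead of raising,
    sqlContext.registerJavaFunction("broken", "com.example.NoSuchUDF")
    # and a later query fails because "broken" was never actually registered:
    # sqlContext.sql("SELECT broken('x')")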
sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java (new file)
Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import org.apache.spark.sql.api.java.UDF1;
+
+/**
+ * Used for registering a Java UDF from PySpark.
+ */
+public class JavaStringLength implements UDF1<String, Integer> {
+  @Override
+  public Integer call(String str) throws Exception {
+    return new Integer(str.length());
+  }
+}

sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
Lines changed: 21 additions & 0 deletions

@@ -87,4 +87,25 @@ public Integer call(String str1, String str2) {
     Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
     Assert.assertEquals(9, result.getInt(0));
   }
+
+  public static class StringLengthTest implements UDF2<String, String, Integer> {
+    @Override
+    public Integer call(String str1, String str2) throws Exception {
+      return new Integer(str1.length() + str2.length());
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  public void udf3Test() {
+    spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(),
+      DataTypes.IntegerType);
+    Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
+    Assert.assertEquals(9, result.getInt(0));
+
+    // returnType is not provided
+    spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null);
+    result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
+    Assert.assertEquals(9, result.getInt(0));
+  }
 }
