|
17 | 17 |
|
18 | 18 | package org.apache.spark.sql
|
19 | 19 |
|
| 20 | +import java.io.IOException |
| 21 | +import java.lang.reflect.{ParameterizedType, Type} |
| 22 | + |
20 | 23 | import scala.reflect.runtime.universe.TypeTag
|
21 | 24 | import scala.util.Try
|
22 | 25 |
|
| 26 | +import com.google.common.reflect.TypeToken |
| 27 | + |
23 | 28 | import org.apache.spark.annotation.InterfaceStability
|
24 | 29 | import org.apache.spark.internal.Logging
|
25 | 30 | import org.apache.spark.sql.api.java._
|
| 31 | +import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} |
26 | 32 | import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
|
27 |
| -import org.apache.spark.sql.catalyst.ScalaReflection |
28 | 33 | import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
|
29 | 34 | import org.apache.spark.sql.execution.aggregate.ScalaUDAF
|
30 | 35 | import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
|
31 | 36 | import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction}
|
32 |
| -import org.apache.spark.sql.types.DataType |
| 37 | +import org.apache.spark.sql.types.{DataType, DataTypes} |
| 38 | +import org.apache.spark.util.Utils |
33 | 39 |
|
34 | 40 | /**
|
35 | 41 | * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this.
|
@@ -413,6 +419,71 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
|
413 | 419 | //////////////////////////////////////////////////////////////////////////////////////////////
|
414 | 420 | //////////////////////////////////////////////////////////////////////////////////////////////
|
415 | 421 |
|
/**
 * Register a Java UDF class using reflection, for use from pyspark.
 *
 * The class is loaded by name and must implement exactly one parameterized interface whose
 * canonical name starts with `org.apache.spark.sql.api.java.UDF`. An instance is created through
 * the no-argument constructor, and the number of type arguments on that interface selects the
 * matching `register` overload (N type arguments dispatch to `UDF(N-1)`).
 *
 * Failures to load or instantiate the class, and unsupported numbers of type arguments, are
 * reported with `logError` rather than thrown; in those cases nothing is registered.
 *
 * @param name udf name
 * @param className fully qualified class name of udf
 * @param returnDataType return type of udf. If it is null, spark would try to infer
 *                       via reflection.
 * @throws IOException if the class implements no UDF interface, or more than one
 */
private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = {

  // True when `rawType` is a class whose canonical name marks it as one of the Java UDF
  // interfaces. `getCanonicalName` returns null for local/anonymous types, hence the Option.
  def isJavaUdfInterface(rawType: Type): Boolean = rawType match {
    case c: Class[_] =>
      Option(c.getCanonicalName).exists(_.startsWith("org.apache.spark.sql.api.java.UDF"))
    case _ => false
  }

  try {
    val clazz = Utils.classForName(className)
    val udfInterfaces = clazz.getGenericInterfaces.collect {
      case p: ParameterizedType if isJavaUdfInterface(p.getRawType) => p
    }

    udfInterfaces match {
      case Array() =>
        throw new IOException(s"UDF class ${className} doesn't implement any UDF interface")

      case Array(udfInterface) =>
        try {
          val udf = clazz.newInstance()
          // The last type argument of a UDFn interface is its return type.
          val typeArguments = udfInterface.getActualTypeArguments
          // Only fall back to reflection-based inference when no return type was supplied.
          val returnType = Option(returnDataType).getOrElse(
            JavaTypeInference.inferDataType(TypeToken.of(typeArguments.last))._1)

          // The dispatch table below is kept one case per line for readability.
          // scalastyle:off line.size.limit
          typeArguments.length match {
            case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType)
            case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType)
            case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType)
            case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType)
            case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType)
            case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType)
            case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType)
            case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType)
            case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType)
            case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 13 => register(name, udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 15 => register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
            case n => logError(s"UDF class with ${n} type arguments is not supported ")
          }
          // scalastyle:on line.size.limit
        } catch {
          // Class.newInstance signals a missing or non-public no-arg constructor with
          // InstantiationException / IllegalAccessException. IllegalArgumentException is kept
          // so that anything previously logged here is still logged rather than propagated.
          case e @ (_: InstantiationException | _: IllegalAccessException |
                    _: IllegalArgumentException) =>
            logError(
              s"Can not instantiate class ${className}, " +
                "please make sure it has public non argument constructor",
              e)
        }

      case _ =>
        throw new IOException(
          s"It is invalid to implement multiple UDF interfaces, UDF class ${className}")
    }
  } catch {
    case e: ClassNotFoundException =>
      logError(
        s"Can not load class ${className}, please make sure it is on the classpath",
        e)
  }

}
| 486 | + |
416 | 487 | /**
|
417 | 488 | * Register a user-defined function with 1 arguments.
|
418 | 489 | * @since 1.3.0
|
|
0 commit comments