Skip to content

Commit 8601fce

Browse files
committed
Looks up temporary function using execution Hive client
1 parent e2957bc commit 8601fce

File tree

3 files changed

+56
-5
lines changed

3 files changed

+56
-5
lines changed

sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,51 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
463463
assert(conf.get("spark.sql.hive.version") === Some("1.2.1"))
464464
}
465465
}
466+
467+
test("SPARK-11595 ADD JAR with input path having URL scheme") {
468+
withJdbcStatement { statement =>
469+
val jarPath = "../hive/src/test/resources/TestUDTF.jar"
470+
val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath"
471+
472+
Seq(
473+
s"ADD JAR $jarURL",
474+
s"""CREATE TEMPORARY FUNCTION udtf_count2
475+
|AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
476+
""".stripMargin
477+
).foreach(statement.execute)
478+
479+
val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2")
480+
481+
assert(rs1.next())
482+
assert(rs1.getString(1) === "Function: udtf_count2")
483+
484+
assert(rs1.next())
485+
assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") {
486+
rs1.getString(1)
487+
}
488+
489+
assert(rs1.next())
490+
assert(rs1.getString(1) === "Usage: To be added.")
491+
492+
val dataPath = "../hive/src/test/resources/data/files/kv1.txt"
493+
494+
Seq(
495+
s"CREATE TABLE test_udtf(key INT, value STRING)",
496+
s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf"
497+
).foreach(statement.execute)
498+
499+
val rs2 = statement.executeQuery(
500+
"SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS cc")
501+
502+
assert(rs2.next())
503+
assert(rs2.getInt(1) === 97)
504+
assert(rs2.getInt(2) === 500)
505+
506+
assert(rs2.next())
507+
assert(rs2.getInt(1) === 97)
508+
assert(rs2.getInt(2) === 500)
509+
}
510+
}
466511
}
467512

468513
class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ class HiveContext private[hive](
454454
// Note that HiveUDFs will be overridden by functions registered in this context.
455455
@transient
456456
override protected[sql] lazy val functionRegistry: FunctionRegistry =
457-
new HiveFunctionRegistry(FunctionRegistry.builtin.copy()) {
457+
new HiveFunctionRegistry(FunctionRegistry.builtin.copy(), this) {
458458
override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
459459
// Hive Registry need current database to lookup function
460460
// TODO: the current database of executionHive should be consistent with metadataHive

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,23 @@ import org.apache.spark.sql.hive.HiveShim._
4646
import org.apache.spark.sql.types._
4747

4848

49-
private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
49+
private[hive] class HiveFunctionRegistry(
50+
underlying: analysis.FunctionRegistry,
51+
hiveContext: HiveContext)
5052
extends analysis.FunctionRegistry with HiveInspectors {
5153

52-
def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name)
54+
def getFunctionInfo(name: String): FunctionInfo = {
55+
hiveContext.executionHive.withHiveState {
56+
FunctionRegistry.getFunctionInfo(name)
57+
}
58+
}
5359

5460
override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
5561
Try(underlying.lookupFunction(name, children)).getOrElse {
5662
// We only look it up to see if it exists, but do not include it in the HiveUDF since it is
5763
// not always serializable.
5864
val functionInfo: FunctionInfo =
59-
Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
65+
Option(getFunctionInfo(name.toLowerCase)).getOrElse(
6066
throw new AnalysisException(s"undefined function $name"))
6167

6268
val functionClassName = functionInfo.getFunctionClass.getName
@@ -110,7 +116,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
110116
override def lookupFunction(name: String): Option[ExpressionInfo] = {
111117
underlying.lookupFunction(name).orElse(
112118
Try {
113-
val info = FunctionRegistry.getFunctionInfo(name)
119+
val info = getFunctionInfo(name)
114120
val annotation = info.getFunctionClass.getAnnotation(classOf[Description])
115121
if (annotation != null) {
116122
Some(new ExpressionInfo(

0 commit comments

Comments
 (0)