
Commit b767cee

liancheng authored and yhuai committed
[SPARK-11191][SPARK-11311][SQL] Backports #9664 and #9277 to branch-1.5
The main purpose of this PR is to backport #9664, which depends on #9277.

Author: Cheng Lian <lian@databricks.com>

Closes #9671 from liancheng/spark-11191.fix-temp-function.branch-1_5.
1 parent: 330961b

Showing 4 changed files with 77 additions and 16 deletions.

sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala

Lines changed: 50 additions & 3 deletions
@@ -256,9 +256,9 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
     { statement =>

       val queries = Seq(
-          s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}=291",
-          "SET hive.cli.print.header=true"
-        )
+        s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}=291",
+        "SET hive.cli.print.header=true"
+      )

       queries.map(statement.execute)
       val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}")
@@ -458,6 +458,53 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
       assert(conf.get("spark.sql.hive.version") === Some("1.2.1"))
     }
   }
+
+  test("SPARK-11595 ADD JAR with input path having URL scheme") {
+    withJdbcStatement { statement =>
+      statement.executeQuery("SET spark.sql.hive.thriftServer.async=true")
+
+      val jarPath = "../hive/src/test/resources/TestUDTF.jar"
+      val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath"
+
+      Seq(
+        s"ADD JAR $jarURL",
+        s"""CREATE TEMPORARY FUNCTION udtf_count2
+           |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
+         """.stripMargin
+      ).foreach(statement.execute)
+
+      val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2")
+
+      assert(rs1.next())
+      assert(rs1.getString(1) === "Function: udtf_count2")
+
+      assert(rs1.next())
+      assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") {
+        rs1.getString(1)
+      }
+
+      assert(rs1.next())
+      assert(rs1.getString(1) === "Usage: To be added.")
+
+      val dataPath = "../hive/src/test/resources/data/files/kv1.txt"
+
+      Seq(
+        s"CREATE TABLE test_udtf(key INT, value STRING)",
+        s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf"
+      ).foreach(statement.execute)
+
+      val rs2 = statement.executeQuery(
+        "SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS cc")
+
+      assert(rs2.next())
+      assert(rs2.getInt(1) === 97)
+      assert(rs2.getInt(2) === 500)
+
+      assert(rs2.next())
+      assert(rs2.getInt(1) === 97)
+      assert(rs2.getInt(2) === 500)
+    }
+  }
 }

 class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
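For readers who want to reproduce the new test's flow outside the suite, here is a minimal sketch of the same ADD JAR / CREATE TEMPORARY FUNCTION sequence over plain JDBC. The connection URL, credentials, and jar path are hypothetical placeholders; it assumes the Hive JDBC driver is on the classpath and a Thrift server is listening:

    import java.sql.DriverManager

    object AddJarOverJdbc {
      def main(args: Array[String]): Unit = {
        // Register the Hive JDBC driver (driver availability is assumed).
        Class.forName("org.apache.hive.jdbc.HiveDriver")

        // Hypothetical host/port and jar location; adjust to your deployment.
        val conn = DriverManager.getConnection(
          "jdbc:hive2://localhost:10000/default", "user", "")
        val stmt = conn.createStatement()
        try {
          // Ship the jar by URL, then bind a temporary function to a class inside it.
          stmt.execute("ADD JAR file:///tmp/TestUDTF.jar")
          stmt.execute(
            "CREATE TEMPORARY FUNCTION udtf_count2 " +
              "AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'")

          // Before this fix, the lookup below could fail because the function
          // was registered against a different Hive session state.
          val rs = stmt.executeQuery("DESCRIBE FUNCTION udtf_count2")
          while (rs.next()) println(rs.getString(1))
        } finally {
          stmt.close()
          conn.close()
        }
      }
    }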

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 11 additions & 7 deletions
@@ -57,9 +57,11 @@ import org.apache.spark.util.Utils
 /**
  * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext
  */
-private[hive] class HiveQLDialect extends ParserDialect {
+private[hive] class HiveQLDialect(sqlContext: HiveContext) extends ParserDialect {
   override def parse(sqlText: String): LogicalPlan = {
-    HiveQl.parseSql(sqlText)
+    sqlContext.executionHive.withHiveState {
+      HiveQl.parseSql(sqlText)
+    }
   }
 }

@@ -410,7 +412,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
   // Note that HiveUDFs will be overridden by functions registered in this context.
   @transient
   override protected[sql] lazy val functionRegistry: FunctionRegistry =
-    new HiveFunctionRegistry(FunctionRegistry.builtin)
+    new HiveFunctionRegistry(FunctionRegistry.builtin, this)

   /* An analyzer that uses the Hive metastore. */
   @transient
@@ -517,10 +519,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
     }
   }

-  override protected[sql] def dialectClassName = if (conf.dialect == "hiveql") {
-    classOf[HiveQLDialect].getCanonicalName
-  } else {
-    super.dialectClassName
+  protected[sql] override def getSQLDialect(): ParserDialect = {
+    if (conf.dialect == "hiveql") {
+      new HiveQLDialect(this)
+    } else {
+      super.getSQLDialect()
+    }
   }

   @transient
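A note on the getSQLDialect() change: the old dialectClassName hook handed back a class name for reflective, no-arg instantiation, which stops working once HiveQLDialect takes a HiveContext constructor argument. A small sketch with hypothetical stand-in classes shows why:

    // Hypothetical stand-ins for the before/after dialect shapes.
    class NoArgDialect
    class ContextDialect(ctx: AnyRef)

    object DialectReflection {
      def main(args: Array[String]): Unit = {
        // Name-based, reflective construction needs a public no-arg constructor.
        val ok = classOf[NoArgDialect].newInstance()

        // classOf[ContextDialect].newInstance() would throw InstantiationException,
        // so a parameterized dialect must be constructed directly by its owner,
        // which is what getSQLDialect() now does via `new HiveQLDialect(this)`.
        val dialect = new ContextDialect(new Object)
        println(s"$ok, $dialect")
      }
    }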

sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala

Lines changed: 1 addition & 1 deletion
@@ -245,7 +245,7 @@ private[hive] class ClientWrapper(
   /**
    * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive.
    */
-  private def withHiveState[A](f: => A): A = retryLocked {
+  def withHiveState[A](f: => A): A = retryLocked {
     val original = Thread.currentThread().getContextClassLoader
     // Set the thread local metastore client to the client associated with this ClientWrapper.
     Hive.set(client)
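withHiveState is made public here so that HiveQLDialect and HiveFunctionRegistry (in the other files of this commit) can run their Hive lookups with the right thread-local session state and classloader. At its core it is a save/swap/restore loan pattern; a simplified sketch, omitting retryLocked and the Hive session bookkeeping the real method also performs:

    // Simplified sketch of the classloader discipline withHiveState follows.
    // `loader` stands in for the Hive client's classloader, which knows about
    // jars added via ADD JAR.
    def withContextClassLoader[A](loader: ClassLoader)(f: => A): A = {
      val thread = Thread.currentThread()
      val original = thread.getContextClassLoader
      thread.setContextClassLoader(loader)
      try f
      finally thread.setContextClassLoader(original) // restore even if f throws
    }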

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

Lines changed: 15 additions & 5 deletions
@@ -44,17 +44,23 @@ import org.apache.spark.sql.hive.HiveShim._
 import org.apache.spark.sql.types._


-private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
+private[hive] class HiveFunctionRegistry(
+    underlying: analysis.FunctionRegistry,
+    hiveContext: HiveContext)
   extends analysis.FunctionRegistry with HiveInspectors {

-  def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name)
+  def getFunctionInfo(name: String): FunctionInfo = {
+    hiveContext.executionHive.withHiveState {
+      FunctionRegistry.getFunctionInfo(name)
+    }
+  }

   override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
     Try(underlying.lookupFunction(name, children)).getOrElse {
       // We only look it up to see if it exists, but do not include it in the HiveUDF since it is
       // not always serializable.
       val functionInfo: FunctionInfo =
-        Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
+        Option(getFunctionInfo(name.toLowerCase)).getOrElse(
           throw new AnalysisException(s"undefined function $name"))

       val functionClassName = functionInfo.getFunctionClass.getName
@@ -89,7 +95,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
   override def lookupFunction(name: String): Option[ExpressionInfo] = {
     underlying.lookupFunction(name).orElse(
       Try {
-        val info = FunctionRegistry.getFunctionInfo(name)
+        val info = getFunctionInfo(name)
         val annotation = info.getFunctionClass.getAnnotation(classOf[Description])
         if (annotation != null) {
           Some(new ExpressionInfo(
@@ -98,7 +104,11 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
             annotation.value(),
             annotation.extended()))
         } else {
-          None
+          Some(new ExpressionInfo(
+            info.getFunctionClass.getCanonicalName,
+            name,
+            null,
+            null))
         }
       }.getOrElse(None))
   }
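Two behavioral changes land in this file. Routing getFunctionInfo through withHiveState makes the lookup run against the execution Hive client's session state, so functions registered there (for example after ADD JAR) resolve correctly. And the new else branch means DESCRIBE FUNCTION on a UDF without a @Description annotation now yields an ExpressionInfo with null usage instead of None, which is why the test above expects "Usage: To be added.". The registry's two-level lookup itself is a plain orElse/Try fallback; a minimal sketch with hypothetical registries standing in for the built-in and Hive ones:

    import scala.util.Try

    object TwoLevelLookup {
      // `builtin` is consulted first; `hive` is the fallback and may throw on
      // a miss, mirroring underlying.lookupFunction(...).orElse(Try { ... }).
      def lookup(name: String,
                 builtin: Map[String, String],
                 hive: String => String): Option[String] =
        builtin.get(name).orElse(Try(hive(name)).toOption)

      def main(args: Array[String]): Unit = {
        val builtin = Map("upper" -> "builtin upper")
        val hive = Map("udtf_count2" -> "hive UDTF") // Map is a String => String
        println(lookup("upper", builtin, hive))       // Some(builtin upper)
        println(lookup("udtf_count2", builtin, hive)) // Some(hive UDTF)
        println(lookup("missing", builtin, hive))     // None
      }
    }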
