
Commit 9fe30cf

Author: Davies Liu (committed)
Commit message: refactor
1 parent bcaddb3 commit 9fe30cf

File tree
7 files changed: +46 −65 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 22 additions & 16 deletions
@@ -44,37 +44,47 @@ trait FunctionRegistry {
 
   /* Get the class of the registered function by specified name. */
   def lookupFunction(name: String): Option[ExpressionInfo]
-
-  def copy(): FunctionRegistry
 }
 
 class SimpleFunctionRegistry extends FunctionRegistry {
 
   private val functionBuilders =
     StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false)
 
-  override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder)
-  : Unit = {
-    functionBuilders.put(name, (info, builder))
+  override def registerFunction(
+      name: String,
+      info: ExpressionInfo,
+      builder: FunctionBuilder): Unit = {
+    synchronized {
+      functionBuilders.put(name, (info, builder))
+    }
   }
 
   override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
-    val func = functionBuilders.get(name).map(_._2).getOrElse {
-      throw new AnalysisException(s"undefined function $name")
+    val func = synchronized {
+      functionBuilders.get(name).map(_._2).getOrElse {
+        throw new AnalysisException(s"undefined function $name")
+      }
     }
     func(children)
   }
 
-  override def listFunction(): Seq[String] = functionBuilders.iterator.map(_._1).toList.sorted
+  override def listFunction(): Seq[String] = synchronized {
+    functionBuilders.iterator.map(_._1).toList.sorted
+  }
 
   override def lookupFunction(name: String): Option[ExpressionInfo] = {
-    functionBuilders.get(name).map(_._1)
+    synchronized {
+      functionBuilders.get(name).map(_._1)
+    }
   }
 
-  override def copy(): SimpleFunctionRegistry = {
+  def copy(): SimpleFunctionRegistry = {
     val registry = new SimpleFunctionRegistry
-    functionBuilders.iterator.foreach { case (name, (info, builder)) =>
+    synchronized {
+      functionBuilders.iterator.foreach { case (name, (info, builder)) =>
       registry.registerFunction(name, info, builder)
+      }
     }
     registry
   }

@@ -101,10 +111,6 @@ object EmptyFunctionRegistry extends FunctionRegistry {
   override def lookupFunction(name: String): Option[ExpressionInfo] = {
     throw new UnsupportedOperationException
   }
-
-  override def copy(): FunctionRegistry = {
-    this
-  }
 }
 
 

@@ -270,7 +276,7 @@ object FunctionRegistry {
     expression[InputFileName]("input_file_name")
   )
 
-  val builtin: FunctionRegistry = {
+  val builtin: SimpleFunctionRegistry = {
     val fr = new SimpleFunctionRegistry
     expressions.foreach { case (name, (info, builder)) => fr.registerFunction(name, info, builder) }
     fr
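As a side note, the pattern this hunk applies is easy to see in isolation: every read and write of the shared builder map goes through synchronized, and copy() takes its snapshot under the same lock so a new session can be seeded from the builtin registry without racing concurrent registrations. A minimal standalone sketch of that pattern (illustrative names only, not Spark's API):

    import scala.collection.mutable

    class SyncRegistry[T] {
      private val entries = mutable.HashMap.empty[String, T]

      def register(name: String, value: T): Unit = synchronized {
        entries.put(name.toLowerCase, value)
      }

      def lookup(name: String): Option[T] = synchronized {
        entries.get(name.toLowerCase)
      }

      def list(): Seq[String] = synchronized {
        entries.keys.toList.sorted
      }

      // Snapshot under the lock, then populate a fresh, independent instance.
      def copy(): SyncRegistry[T] = {
        val snapshot = synchronized { entries.toList }
        val registry = new SyncRegistry[T]
        snapshot.foreach { case (name, value) => registry.register(name, value) }
        registry
      }
    }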

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 9 additions & 3 deletions
@@ -64,14 +64,17 @@ import org.apache.spark.util.Utils
 class SQLContext private[sql](
     @transient val sparkContext: SparkContext,
     @transient protected[sql] val cacheManager: CacheManager)
-  extends org.apache.spark.Logging
-  with Serializable {
+  extends org.apache.spark.Logging with Serializable {
 
   self =>
 
   def this(sparkContext: SparkContext) = this(sparkContext, new CacheManager)
   def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)
 
+  /**
+   * Returns a SQLContext as new session, with separated SQL configurations, temporary tables,
+   * registered functions, but share the same SparkContext and CacheManager.
+   */
   def newSession(): SQLContext = {
     new SQLContext(sparkContext, cacheManager)
   }

@@ -207,6 +210,9 @@ class SQLContext private[sql](
     conf.dialect
   }
 
+  /**
+   * Add a jar to SQLContext
+   */
   protected[sql] def addJar(path: String): Unit = {
     sparkContext.addJar(path)
   }

@@ -1230,7 +1236,7 @@ object SQLContext {
   }
 
   /**
-   * Clear the SQLContext for current thread
+   * Clear the active SQLContext for current thread
    */
   def clearActive(): Unit = {
     activeContexts.remove()
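The new doc comment above is the heart of the change: a session created with newSession() gets its own SQL configuration, temporary tables and registered functions while reusing the parent's SparkContext and CacheManager. A hedged usage sketch (assumes a Spark 1.x-era SQLContext is already in scope as sqlContext; not taken from the commit):

    val session1 = sqlContext.newSession()
    val session2 = sqlContext.newSession()

    // Each session keeps its own SQLConf ...
    session1.setConf("spark.sql.shuffle.partitions", "10")
    session2.setConf("spark.sql.shuffle.partitions", "200")

    // ... but both sessions share the same underlying SparkContext (and cache).
    assert(session1.sparkContext eq session2.sparkContext)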

sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala

Lines changed: 6 additions & 36 deletions
@@ -26,7 +26,6 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, Map => SMap}
 import scala.util.control.NonFatal
 
-import org.apache.hadoop.hive.conf.HiveConf
 import org.apache.hadoop.hive.metastore.api.FieldSchema
 import org.apache.hadoop.hive.shims.Utils
 import org.apache.hive.service.cli._

@@ -140,24 +139,22 @@ private[hive] class SparkExecuteStatementOperation(
     if (!runInBackground) {
       runInternal()
     } else {
-      val hiveConf = getConfigForOperation()
       val sparkServiceUGI = Utils.getUGI()
 
       // Runnable impl to call runInternal asynchronously,
       // from a different thread
       val backgroundOperation = new Runnable() {
 
         override def run(): Unit = {
-          val doAsAction = new PrivilegedExceptionAction[Object]() {
-            override def run(): Object = {
+          val doAsAction = new PrivilegedExceptionAction[Unit]() {
+            override def run(): Unit = {
               try {
                 runInternal()
               } catch {
                 case e: HiveSQLException =>
                   setOperationException(e)
                   log.error("Error running hive query: ", e)
               }
-              return null
             }
           }
 

@@ -174,7 +171,7 @@ private[hive] class SparkExecuteStatementOperation(
       try {
         // This submit blocks if no background threads are available to run this operation
         val backgroundHandle =
-          getParentSession().getSessionManager().submitBackgroundOperation(backgroundOperation)
+          parentSession.getSessionManager().submitBackgroundOperation(backgroundOperation)
         setBackgroundHandle(backgroundHandle)
       } catch {
         case rejected: RejectedExecutionException =>

@@ -193,6 +190,9 @@ private[hive] class SparkExecuteStatementOperation(
     statementId = UUID.randomUUID().toString
     logInfo(s"Running query '$statement' with $statementId")
     setState(OperationState.RUNNING)
+    val executionHiveClassLoader =
+      hiveContext.executionHive.state.getConf.getClassLoader
+    Thread.currentThread().setContextClassLoader(executionHiveClassLoader)
     HiveThriftServer2.listener.onStatementStart(
       statementId,
       parentSession.getSessionHandle.getSessionId.toString,

@@ -262,34 +262,4 @@ private[hive] class SparkExecuteStatementOperation(
       }
     }
   }
-
-  /**
-   * If there are query specific settings to overlay, then create a copy of config
-   * There are two cases we need to clone the session config that's being passed to hive driver
-   * 1. Async query -
-   *    If the client changes a config setting, that shouldn't reflect in the execution
-   *    already underway
-   * 2. confOverlay -
-   *    The query specific settings should only be applied to the query config and not session
-   * @return new configuration
-   * @throws HiveSQLException
-   */
-  private def getConfigForOperation(): HiveConf = {
-    var sqlOperationConf = getParentSession().getHiveConf()
-    if (!getConfOverlay().isEmpty() || runInBackground) {
-      // clone the partent session config for this query
-      sqlOperationConf = new HiveConf(sqlOperationConf)
-
-      // apply overlay query specific settings, if any
-      getConfOverlay().asScala.foreach { case (k, v) =>
-        try {
-          sqlOperationConf.verifyAndSet(k, v)
-        } catch {
-          case e: IllegalArgumentException =>
-            throw new HiveSQLException("Error applying statement specific settings", e)
-        }
-      }
-    }
-    return sqlOperationConf
-  }
 }
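Two ideas in this file are worth calling out. First, switching the doAs action to PrivilegedExceptionAction[Unit] removes the need for the dummy `return null` that PrivilegedExceptionAction[Object] forced. Second, the operation now sets the thread's context class loader to the execution Hive client's class loader before running, and the per-operation HiveConf cloning helper (getConfigForOperation) is removed. A minimal sketch of the Unit-typed doAs pattern (the runAs helper is hypothetical, not part of the commit):

    import java.security.PrivilegedExceptionAction
    import org.apache.hadoop.security.UserGroupInformation

    // Run `body` as the given UGI; returning Unit means no placeholder value is needed.
    def runAs(ugi: UserGroupInformation)(body: => Unit): Unit = {
      ugi.doAs(new PrivilegedExceptionAction[Unit] {
        override def run(): Unit = body
      })
    }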

sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala

Lines changed: 1 addition & 3 deletions
@@ -130,9 +130,7 @@ case class HiveTableScan(
   }
 
   protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) {
-    sqlContext.asInstanceOf[HiveContext].executionHive.withHiveState {
-      hadoopReader.makeRDDForTable(relation.hiveQlTable)
-    }
+    hadoopReader.makeRDDForTable(relation.hiveQlTable)
   } else {
     hadoopReader.makeRDDForPartitionedTable(
       prunePartitions(relation.getHiveQlPartitions(partitionPruningPred)))

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

Lines changed: 0 additions & 4 deletions
@@ -103,10 +103,6 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
       }
     }.getOrElse(None))
   }
-
-  override def copy(): HiveFunctionRegistry = {
-    this
-  }
 }
 
 private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 1 addition & 1 deletion
@@ -1133,7 +1133,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
     conf.clear()
   }
 
-  // Enable this test once fix the current_database()
+  // TODO: Enable this test once fix SPARK-10902
   ignore("current_database with mutiple sessions") {
     sql("create database a")
     sql("use a")

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala

Lines changed: 7 additions & 2 deletions
@@ -160,10 +160,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   }
 
   test("show functions") {
-    val allFunctions =
+    val allBuiltinFunctions =
       (FunctionRegistry.builtin.listFunction().toSet[String] ++
         org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames.asScala).toList.sorted
-    checkAnswer(sql("SHOW functions"), allFunctions.map(Row(_)))
+    // The TestContext is shared by all the test cases, some functions may be registered before
+    // this, so we check that all the builtin functions are returned.
+    val allFunctions = sql("SHOW functions").collect().map(r => r(0))
+    allBuiltinFunctions.foreach { f =>
+      assert(allFunctions.contains(f))
+    }
     checkAnswer(sql("SHOW functions abs"), Row("abs"))
     checkAnswer(sql("SHOW functions 'abs'"), Row("abs"))
     checkAnswer(sql("SHOW functions abc.abs"), Row("abs"))
