18 changes: 12 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -434,7 +434,8 @@ private[spark] object SQLConf {
*
* SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads).
*/
private[sql] class SQLConf extends Serializable with CatalystConf {
private[sql] class SQLConf(val config: Map[String, String] = Map.empty)
extends Serializable with CatalystConf {
import SQLConf._

/** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
@@ -556,7 +557,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf {

/** Return the value of Spark SQL configuration property for the given key. */
def getConfString(key: String): String = {
Option(settings.get(key)).
getValue(key).
orElse {
// Try to use the default value
Option(sqlConfEntries.get(key)).map(_.defaultValueString)
@@ -571,7 +572,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
*/
def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = {
require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
Option(settings.get(entry.key)).map(entry.valueConverter).getOrElse(defaultValue)
getValue(entry.key).map(entry.valueConverter).getOrElse(defaultValue)
}

/**
@@ -580,10 +581,15 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
*/
def getConf[T](entry: SQLConfEntry[T]): T = {
require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
Option(settings.get(entry.key)).map(entry.valueConverter).orElse(entry.defaultValue).
getValue(entry.key).map(entry.valueConverter).orElse(entry.defaultValue).
getOrElse(throw new NoSuchElementException(entry.key))
}

private def getValue(key: String): Option[String] = {
val conf = Option(settings.get(key))
if (!conf.isDefined) config.get(key) else conf
}

/**
* Return the `string` value of Spark SQL configuration property for the given key. If the key is
* not set yet, return `defaultValue`.
@@ -594,15 +600,15 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
// Only verify configs in the SQLConf object
entry.valueConverter(defaultValue)
}
Option(settings.get(key)).getOrElse(defaultValue)
getValue(key).getOrElse(defaultValue)
}

/**
* Return all the configuration properties that have been set (i.e. not the default).
* This creates a new copy of the config properties in the form of a Map.
*/
def getAllConfs: immutable.Map[String, String] =
settings.synchronized { settings.asScala.toMap }
settings.synchronized { config ++ settings.asScala.toMap }

/**
* Return all the configuration definitions that have been defined in [[SQLConf]]. Each
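Review note: the new `config` constructor parameter plus the private `getValue` give the conf a two-layer lookup. A minimal, standalone sketch of that order with simplified names (the real `getConfString` additionally falls back to the registered `SQLConfEntry` default):

```scala
import java.util.concurrent.ConcurrentHashMap

// Standalone sketch, not the real SQLConf: `settings` holds values set at
// runtime, `overrides` plays the role of the new `config` constructor argument.
class MiniSQLConf(overrides: Map[String, String] = Map.empty) {
  private val settings = new ConcurrentHashMap[String, String]()

  def setConfString(key: String, value: String): Unit = settings.put(key, value)

  // Mirrors the new private getValue: runtime setting first, then the override map.
  private def getValue(key: String): Option[String] =
    Option(settings.get(key)).orElse(overrides.get(key))

  def getConfString(key: String, defaultValue: String): String =
    getValue(key).getOrElse(defaultValue)
}

object MiniSQLConfDemo extends App {
  val conf = new MiniSQLConf(Map("spark.sql.shuffle.partitions" -> "5"))
  println(conf.getConfString("spark.sql.shuffle.partitions", "200")) // "5": override beats the default
  conf.setConfString("spark.sql.shuffle.partitions", "10")
  println(conf.getConfString("spark.sql.shuffle.partitions", "200")) // "10": runtime setting beats the override
}
```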
34 changes: 15 additions & 19 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -60,14 +60,21 @@ import org.apache.spark.util.Utils
*
* @since 1.0.0
*/
class SQLContext(@transient val sparkContext: SparkContext)
class SQLContext(
@transient val sparkContext: SparkContext,
@transient optionConf: Map[String, String] = Map.empty)
extends org.apache.spark.Logging
with Serializable {

self =>

// for java invocation
def this(sparkContext: SparkContext) = this(sparkContext, Map.empty)

def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)

protected[sql] def defaultOverrides(): Map[String, String] = Map.empty

/**
* @return Spark SQL configuration
*/
@@ -198,36 +205,27 @@ class SQLContext(@transient val sparkContext: SparkContext)
}

@transient
protected[sql] val defaultSession = createSession()
protected[sql] lazy val defaultSession = openSession()

protected[sql] def dialectClassName = if (conf.dialect == "sql") {
classOf[DefaultParserDialect].getCanonicalName
} else {
conf.dialect
}

{
protected[sql] val properties = {
// We extract spark sql settings from SparkContext's conf and put them to
// Spark SQL's conf.
// First, we populate the SQLConf (conf). So, we can make sure that other values using
// those settings in their construction can get the correct settings.
// For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version
// and spark.sql.hive.metastore.jars to get correctly constructed.
val properties = new Properties
sparkContext.getConf.getAll.foreach {
(sparkContext.getConf.getAll ++ optionConf).foreach {
case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value)
case _ =>
}
// We directly put those settings to conf to avoid of calling setConf, which may have
// side-effects. For example, in HiveContext, setConf may cause executionHive and metadataHive
// get constructed. If we call setConf directly, the constructed metadataHive may have
// wrong settings, or the construction may fail.
conf.setConf(properties)
// After we have populated SQLConf, we call setConf to populate other confs in the subclass
// (e.g. hiveconf in HiveContext).
properties.asScala.foreach {
case (key, value) => setConf(key, value)
}
properties
}

@transient
@@ -869,10 +867,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
}

protected[sql] def openSession(): SQLSession = {
detachSession()
val session = createSession()
tlSession.set(session)

session.conf.setConf(properties)
setSession(session)
session
}

@@ -889,13 +886,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
}

protected[sql] def setSession(session: SQLSession): Unit = {
detachSession()
tlSession.set(session)
}

protected[sql] class SQLSession {
// Note that this is a lazy val so we can override the default value in subclasses.
protected[sql] lazy val conf: SQLConf = new SQLConf
protected[sql] lazy val conf: SQLConf = new SQLConf(defaultOverrides())
}

/**
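Review note: a hedged sketch of the wiring the SQLContext changes set up (names other than `defaultOverrides`/`openSession` are simplified stand-ins, and the real SQLConf keeps the overrides separate from its runtime settings). Each session's conf starts from the subclass-provided defaults, and `openSession()` then applies the spark.sql.* properties gathered from SparkConf and `optionConf`:

```scala
import scala.collection.mutable

// Simplified model of the session wiring, not the real SQLContext.
class MiniContext(sparkSqlProps: Map[String, String]) {
  // Subclasses (e.g. a test context) override this to seed every session's conf.
  protected def defaultOverrides(): Map[String, String] = Map.empty

  class Session {
    // Analogous to `new SQLConf(defaultOverrides())`.
    val conf: mutable.Map[String, String] = mutable.Map(defaultOverrides().toSeq: _*)
  }

  private val tlSession = new ThreadLocal[Session]()

  // Analogous to openSession(): build the session, apply the collected
  // spark.sql.* properties, and attach it to the current thread.
  def openSession(): Session = {
    val session = new Session()
    sparkSqlProps.foreach { case (k, v) => session.conf(k) = v }
    tlSession.set(session)
    session
  }

  def currentSession(): Session = Option(tlSession.get()).getOrElse(openSession())
}

// The TestSQLContext pattern below: just override the defaults map.
class MiniTestContext(props: Map[String, String]) extends MiniContext(props) {
  override protected def defaultOverrides(): Map[String, String] =
    super.defaultOverrides() ++ Map("spark.sql.shuffle.partitions" -> "5")
}
```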
@@ -17,7 +17,7 @@

package org.apache.spark.sql

import org.apache.spark.sql.test.{TestSQLContext, SharedSQLContext}
import org.apache.spark.sql.test.SharedSQLContext


class SQLConfSuite extends QueryTest with SharedSQLContext {
@@ -37,7 +37,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
// Clear the conf.
sqlContext.conf.clear()
// After clear, only overrideConfs used by unit test should be in the SQLConf.
assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs)
assert(sqlContext.getAllConfs === sqlContext.defaultOverrides())

sqlContext.setConf(testKey, testVal)
assert(sqlContext.getConf(testKey) === testVal)
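Review note on why the assertion can switch from `TestSQLContext.overrideConfs` to `sqlContext.defaultOverrides()`: with the constructor overrides in place, `getAllConfs` returns `config ++ settings`, so clearing the runtime settings leaves exactly the constructor-supplied overrides visible. A standalone illustration under those simplified assumptions (not the real SQLConf):

```scala
import scala.collection.concurrent.TrieMap

// Simplified stand-in: runtime settings shadow the constructor overrides,
// and clear() wipes only the runtime layer.
class OverlayConf(overrides: Map[String, String] = Map.empty) {
  private val settings = TrieMap.empty[String, String]

  def setConf(key: String, value: String): Unit = settings(key) = value
  def clear(): Unit = settings.clear()

  // Mirrors `config ++ settings.asScala.toMap` in getAllConfs.
  def getAllConfs: Map[String, String] = overrides ++ settings.toMap
}

object OverlayConfCheck extends App {
  val overrides = Map("spark.sql.shuffle.partitions" -> "5")
  val conf = new OverlayConf(overrides)
  conf.setConf("spark.sql.testkey", "true")
  conf.clear()
  assert(conf.getAllConfs == overrides) // same shape as the SQLConfSuite assertion above
  println("clear() keeps the constructor overrides visible")
}
```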
11 changes: 6 additions & 5 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._

/** A SQL Dialect for testing purpose, and it can not be nested type */
@@ -1025,16 +1025,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val nonexistentKey = "nonexistent"

// "set" itself returns all config variables currently specified in SQLConf.
assert(sql("SET").collect().size === TestSQLContext.overrideConfs.size)
val overrides = sqlContext.defaultOverrides()
assert(sql("SET").collect().size === overrides.size)
sql("SET").collect().foreach { row =>
val key = row.getString(0)
val value = row.getString(1)
assert(
TestSQLContext.overrideConfs.contains(key),
overrides.contains(key),
s"$key should exist in SQLConf.")
assert(
TestSQLContext.overrideConfs(key) === value,
s"The value of $key should be ${TestSQLContext.overrideConfs(key)} instead of $value.")
overrides(key) === value,
s"The value of $key should be ${overrides(key)} instead of $value.")
}
val overrideConfs = sql("SET").collect()

@@ -31,26 +31,8 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self =>
new SparkConf().set("spark.sql.testkey", "true")))
}

// Make sure we set those test specific confs correctly when we create
// the SQLConf as well as when we call clear.
protected[sql] override def createSession(): SQLSession = new this.SQLSession()

/** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */
protected[sql] class SQLSession extends super.SQLSession {
protected[sql] override lazy val conf: SQLConf = new SQLConf {

clear()

override def clear(): Unit = {
super.clear()

// Make sure we start with the default test configs even after clear
TestSQLContext.overrideConfs.map {
case (key, value) => setConfString(key, value)
}
}
}
}
protected[sql] override def defaultOverrides(): Map[String, String] =
super.defaultOverrides ++ Seq(SQLConf.SHUFFLE_PARTITIONS.key -> "5").toMap

// Needed for Java tests
def loadTestData(): Unit = {
@@ -61,14 +43,3 @@
protected override def sqlContext: SQLContext = self
}
}

private[sql] object TestSQLContext {

/**
* A map used to store all confs that need to be overridden in sql/core unit tests.
*/
val overrideConfs: Map[String, String] =
Map(
// Fewer shuffle partitions to speed up testing.
SQLConf.SHUFFLE_PARTITIONS.key -> "5")
}
@@ -35,7 +35,7 @@ import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab
import org.apache.spark.util.{ShutdownHookManager, Utils}
import org.apache.spark.util.ShutdownHookManager
import org.apache.spark.{Logging, SparkContext}


@@ -82,11 +82,13 @@ object HiveThriftServer2 extends Logging {
}

try {
val server = new HiveThriftServer2(SparkSQLEnv.hiveContext)
server.init(SparkSQLEnv.hiveContext.hiveconf)
server.start()
val hiveContext = SparkSQLEnv.hiveContext
val server = new HiveThriftServer2(hiveContext)
server.init(hiveContext.hiveconf)
hiveContext.execute(hiveContext.hiveconf, server.start())

logInfo("HiveThriftServer2 started")
listener = new HiveThriftServer2Listener(server, SparkSQLEnv.hiveContext.conf)
listener = new HiveThriftServer2Listener(server, hiveContext.conf)
SparkSQLEnv.sparkContext.addSparkListener(listener)
uiTab = if (SparkSQLEnv.sparkContext.getConf.getBoolean("spark.ui.enabled", true)) {
Some(new ThriftServerTab(SparkSQLEnv.sparkContext))
@@ -136,12 +136,12 @@ private[hive] class SparkExecuteStatementOperation(
}
}

override def run(): Unit = {
override def runInternal(): Unit = {
setState(OperationState.PENDING)
setHasResultSet(true) // avoid no resultset for async run

if (!runInBackground) {
runInternal()
execute()
} else {
val parentSessionState = SessionState.get()
val hiveConf = getConfigForOperation()
@@ -157,24 +157,26 @@
val doAsAction = new PrivilegedExceptionAction[Object]() {
override def run(): Object = {

Hive.set(sessionHive)
SessionState.setCurrentSessionState(parentSessionState)

// User information is part of the metastore client member in Hive
hiveContext.setSession(currentSqlSession)

// Always use the latest class loader provided by executionHive's state.
val executionHiveClassLoader =
hiveContext.executionHive.state.getConf.getClassLoader
sessionHive.getConf.setClassLoader(executionHiveClassLoader)
parentSessionState.getConf.setClassLoader(executionHiveClassLoader)

Hive.set(sessionHive)
SessionState.setCurrentSessionState(parentSessionState)
try {
runInternal()
execute()
} catch {
case e: HiveSQLException =>
setOperationException(e)
log.error("Error running hive query: ", e)
}
return null
null
}
}

@@ -206,7 +208,7 @@ private[hive] class SparkExecuteStatementOperation(
}
}

override def runInternal(): Unit = {
private def execute(): Unit = {
statementId = UUID.randomUUID().toString
logInfo(s"Running query '$statement' with $statementId")
setState(OperationState.RUNNING)
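Review note on the renames in this file: the override now targets `runInternal()`, the hook the parent Hive `Operation` drives from its own `run()`, and the former body becomes a private `execute()` that runs either inline or inside the background doAs action. A schematic, standalone sketch of that control flow (the state strings and executor here stand in for the Hive OperationState machinery and async pool; they are not the real types):

```scala
import java.util.concurrent.Executors

// Schematic only: OperationState, HiveSQLException, doAs, etc. are replaced by
// simple placeholders.
abstract class MiniOperation(runInBackground: Boolean) {
  @volatile protected var state: String = "INITIALIZED"
  protected def setState(s: String): Unit = { state = s }

  // What the server invokes; mirrors the overridden runInternal().
  def runInternal(): Unit = {
    setState("PENDING")
    if (!runInBackground) {
      execute() // synchronous path
    } else {
      // Background path: run execute() on a pool, analogous to the
      // PrivilegedExceptionAction submitted to the async pool in the real code.
      val pool = Executors.newSingleThreadExecutor()
      pool.submit(new Runnable {
        override def run(): Unit =
          try execute()
          catch { case _: Exception => setState("ERROR") }
      })
      pool.shutdown()
    }
  }

  // The actual statement execution, formerly the body of runInternal().
  protected def execute(): Unit
}
```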
@@ -40,10 +40,12 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, hiveContext: HiveContext)
extends CLIService(hiveServer)
with ReflectedCompositeService {

private var sparkSqlSessionManager: SparkSQLSessionManager = null

override def init(hiveConf: HiveConf) {
setSuperField(this, "hiveConf", hiveConf)

val sparkSqlSessionManager = new SparkSQLSessionManager(hiveServer, hiveContext)
sparkSqlSessionManager = new SparkSQLSessionManager(hiveServer, hiveContext)
setSuperField(this, "sessionManager", sparkSqlSessionManager)
addService(sparkSqlSessionManager)
var sparkServiceUGI: UserGroupInformation = null
@@ -70,6 +72,35 @@
case _ => super.getInfo(sessionHandle, getInfoType)
}
}

private def withMetaContext[A](sessionHandle: SessionHandle, f: => A): A = {
val session = sparkSqlSessionManager.getSession(sessionHandle).getSessionState
hiveContext.executionMetaHive.withHiveState(session, f)
}

override def getCatalogs(sessionHandle: SessionHandle): OperationHandle =
withMetaContext(sessionHandle,
super.getCatalogs(sessionHandle))

override def getSchemas(sessionHandle: SessionHandle, catalogName: String, schemaName: String):
OperationHandle = withMetaContext(sessionHandle,
super.getSchemas(sessionHandle, catalogName, schemaName))

override def getTables(sessionHandle: SessionHandle, catalogName: String, schemaName: String,
tableName: String, tableTypes: java.util.List[String]): OperationHandle =
withMetaContext(sessionHandle,
super.getTables(sessionHandle, catalogName, schemaName, tableName, tableTypes))

override def getTableTypes(sessionHandle: SessionHandle): OperationHandle =
withMetaContext(sessionHandle, super.getTableTypes(sessionHandle))

override def getColumns(sessionHandle: SessionHandle, catalogName: String, schemaName: String,
tableName: String, columnName: String): OperationHandle = withMetaContext(sessionHandle,
super.getColumns(sessionHandle, catalogName, schemaName, tableName, columnName))

override def getFunctions(sessionHandle: SessionHandle, catalogName: String, schemaName: String,
functionName: String): OperationHandle = withMetaContext(sessionHandle,
super.getFunctions(sessionHandle, catalogName, schemaName, functionName))
}

private[thriftserver] trait ReflectedCompositeService { this: AbstractService =>
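Review note: the new metadata overrides all funnel through the single `withMetaContext` helper so each delegated CLIService call runs with the calling session's Hive state (`hiveContext.executionMetaHive.withHiveState(session, f)` in the diff above). A generic, standalone sketch of that wrap-the-super-call pattern; the service and state types here are placeholders, not the Hive interfaces:

```scala
// Generic sketch of the pattern: look up per-session state, run the delegated
// call "inside" it, and keep each override to a one-line wrapper.
trait MetaService {
  def getCatalogs(session: String): Seq[String]
  def getSchemas(session: String, catalog: String): Seq[String]
}

class BaseService extends MetaService {
  def getCatalogs(session: String): Seq[String] = Seq("default")
  def getSchemas(session: String, catalog: String): Seq[String] = Seq("public")
}

class SessionAwareService(sessionState: String => AnyRef) extends BaseService {
  // Analogous to withMetaContext: fetch the session's state, then evaluate f.
  // The real code wraps f via executionMetaHive.withHiveState(session, f).
  private def withSessionState[A](session: String)(f: => A): A = {
    val state = sessionState(session)
    require(state != null, s"no state registered for session $session")
    f
  }

  override def getCatalogs(session: String): Seq[String] =
    withSessionState(session)(super.getCatalogs(session))

  override def getSchemas(session: String, catalog: String): Seq[String] =
    withSessionState(session)(super.getSchemas(session, catalog))
}

object MetaDemo extends App {
  val states = Map("s1" -> new AnyRef)
  val svc = new SessionAwareService(states.apply)
  println(svc.getCatalogs("s1")) // List(default), evaluated with s1's state looked up
}
```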