Skip to content

Commit

Permalink
apache#11 SQLConf support thread local prop
Browse files Browse the repository at this point in the history
  • Loading branch information
hn5092 authored and Wayne1c committed Apr 15, 2020
1 parent 529ae3e commit 5204dfd
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import java.util.{Map => JMap}
import scala.collection.mutable.HashMap
import scala.util.matching.Regex

import org.apache.commons.lang3.SerializationUtils

private object ConfigReader {

private val REF_RE = "\\$\\{(?:(\\w+?):)?(\\S+?)\\}".r
Expand Down Expand Up @@ -56,6 +58,17 @@ private[spark] class ConfigReader(conf: ConfigProvider) {
bindEnv(new EnvProvider())
bindSystem(new SystemProvider())

protected[spark] val localProperties =
new InheritableThreadLocal[java.util.HashMap[String, String]] {
override protected def childValue(parent: java.util.HashMap[String, String]):
java.util.HashMap[String, String] = {
// Note: make a clone such that changes in the parent properties aren't reflected in
// the those of the children threads, which has confusing semantics (SPARK-10563).
SerializationUtils.clone(parent)
}
override protected def initialValue(): java.util.HashMap[String, String] =
new java.util.HashMap[String, String]()
}
/**
* Binds a prefix to a provider. This method is not thread-safe and should be called
* before the instance is used to expand values.
Expand All @@ -76,7 +89,9 @@ private[spark] class ConfigReader(conf: ConfigProvider) {
/**
* Reads a configuration key from the default provider, and apply variable substitution.
*/
def get(key: String): Option[String] = conf.get(key).map(substitute)
def get(key: String): Option[String] = {
Option(localProperties.get().get(key)).orElse(conf.get(key)).map(substitute)
}

/**
* Perform variable substitution on the given input string.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,13 @@ class SQLConf extends Serializable with Logging {
getOrElse(throw new NoSuchElementException(key))
}

def setLocalProperty(key: String, value: String): Unit = {
if (value == null) {
reader.localProperties.get().remove(key)
} else {
reader.localProperties.get().put(key, value)
}
}
/**
* Return the value of Spark SQL configuration property for the given key. If the key is not set
* yet, return `defaultValue`. This is useful when `defaultValue` in ConfigEntry is not the
Expand Down Expand Up @@ -1168,8 +1175,13 @@ class SQLConf extends Serializable with Logging {
* Return all the configuration properties that have been set (i.e. not the default).
* This creates a new copy of the config properties in the form of a Map.
*/
def getAllConfs: immutable.Map[String, String] =
settings.synchronized { settings.asScala.toMap }
def getAllConfs: immutable.Map[String, String] = {
settings.synchronized {
var map = settings.asScala.toMap
reader.localProperties.get().asScala.foreach(entry => map += (entry._1 -> entry._2))
map
}
}

/**
* Return all the configuration definitions that have been defined in [[SQLConf]]. Each
Expand Down

0 comments on commit 5204dfd

Please sign in to comment.