SPARK-1189: Add Security to Spark - Akka, Http, ConnectionManager, UI use servlets #33

Changes from 1 commit
Pass securityManager and SparkConf around where we can. Switch to use sparkConf for reading config wherever possible.

Added ConnectionManagerSuite unit tests.
tgravescs committed Mar 6, 2014
commit 13733e1532cbd3fcd0bef59d4078771bf58892d2
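The theme running through the diffs below: configuration reads move from System.getProperty to a SparkConf instance, and a SecurityManager built from that conf is passed explicitly to the actor system, metrics system, and web UIs. A minimal sketch of the resulting setup pattern (not code from this commit; SecurityManager is private[spark], so a snippet like this only compiles inside the org.apache.spark package):

```scala
package org.apache.spark

// Sketch only: illustrates the pattern this commit moves to, not code from the diff.
object SecuritySetupSketch {
  def main(args: Array[String]): Unit = {
    // Security settings now come from SparkConf rather than System.getProperty(...).
    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.ui.acls.enable", "true")
      // Off YARN, an explicit shared secret is required once authentication is on
      // (see the SecurityManager changes below); the value here is a placeholder.
      .set("spark.authenticate.secret", "placeholder-secret")

    // The SecurityManager constructor now takes the SparkConf; the instance is then
    // handed to AkkaUtils.createActorSystem, MetricsSystem.createMetricsSystem,
    // the master/worker web UIs, and so on, as the diffs below show.
    val securityManager = new SecurityManager(conf)
    println("authentication enabled: " + securityManager.isAuthenticationEnabled)
  }
}
```

Keeping the settings on the SparkConf means they travel with the conf object handed to each component instead of living in process-global JVM properties.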
45 changes: 22 additions & 23 deletions core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -122,9 +122,8 @@ import scala.collection.mutable.ArrayBuffer
* filters to do authentication. That authentication then happens via the ResourceManager Proxy
* and Spark will use that to do authorization against the view acls.
*
* For other Spark deployments, the shared secret must be specified via the SPARK_SECRET
* environment variable. This isn't ideal but it means only the user who starts the process
* has access to view that variable.
* For other Spark deployments, the shared secret must be specified via the
* spark.authenticate.secret config.
* All the nodes (Master and Workers) and the applications need to have the same shared secret.
* This again is not ideal as one user could potentially affect another user's application.
* This should be enhanced in the future to provide better protection.
@@ -133,23 +132,24 @@ import scala.collection.mutable.ArrayBuffer
* authorization. If no filter is in place the user is generally null and no authorization
* can take place.
*/
private[spark] class SecurityManager extends Logging {

private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {

// key used to store the spark secret in the Hadoop UGI
private val sparkSecretLookupKey = "sparkCookie"

private val authOn = System.getProperty("spark.authenticate", "false").toBoolean
private val uiAclsOn = System.getProperty("spark.ui.acls.enable", "false").toBoolean
private val authOn = sparkConf.getBoolean("spark.authenticate", false)
private val uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false)

// always add the current user and SPARK_USER to the viewAcls
private val aclUsers = ArrayBuffer[String](System.getProperty("user.name", ""),
Option(System.getenv("SPARK_USER")).getOrElse(""))
aclUsers ++= System.getProperty("spark.ui.view.acls", "").split(',')
aclUsers ++= sparkConf.get("spark.ui.view.acls", "").split(',')
private val viewAcls = aclUsers.map(_.trim()).filter(!_.isEmpty).toSet

private val secretKey = generateSecretKey()
logInfo("SecurityManager, is authentication enabled: " + authOn +
" are ui acls enabled: " + uiAclsOn)
" are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString())

// Set our own authenticator to properly negotiate user/password for HTTP connections.
// This is needed by the HTTP client fetching from the HttpServer. Put here so its
@@ -176,34 +176,33 @@ private[spark] class SecurityManager extends Logging {
* The way the key is stored depends on the Spark deployment mode. Yarn
* uses the Hadoop UGI.
*
* For non-Yarn deployments, If the environment variable is not set
* we throw an exception.
* For non-Yarn deployments, If the config variable is not set
* we throw an exception.
*/
private def generateSecretKey(): String = {
if (!isAuthenticationEnabled) return null
// first check to see if the secret is already set, else generate a new one if on yarn
if (SparkHadoopUtil.get.isYarnMode) {
val sCookie = if (SparkHadoopUtil.get.isYarnMode) {
val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey)
if (secretKey != null) {
logDebug("in yarn mode, getting secret from credentials")
return new Text(secretKey).toString
} else {
logDebug("getSecretKey: yarn mode, secret key from credentials is null")
}
}
val secret = System.getProperty("SPARK_SECRET", System.getenv("SPARK_SECRET"))
if (secret != null && !secret.isEmpty()) return secret
val sCookie = if (SparkHadoopUtil.get.isYarnMode) {
// generate one
akka.util.Crypt.generateSecureCookie
} else {
throw new Exception("Error: a secret key must be specified via SPARK_SECRET env variable")
}
if (SparkHadoopUtil.get.isYarnMode) {
// if we generated the secret then we must be the first so lets set it so t
val cookie = akka.util.Crypt.generateSecureCookie
// if we generated the secret then we must be the first so lets set it so t
// gets used by everyone else
SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, sCookie)
SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, cookie)
logInfo("adding secret to credentials in yarn mode")
cookie
} else {
// user must have set spark.authenticate.secret config
sparkConf.getOption("spark.authenticate.secret") match {
case Some(value) => value
case None => throw new Exception("Error: a secret key must be specified via the " +
"spark.authenticate.secret config")
}
}
sCookie
}
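To summarize the SecurityManager change: for non-YARN deployments the shared secret now comes from the spark.authenticate.secret config instead of the SPARK_SECRET environment variable, while YARN deployments still generate a secret and stash it in the Hadoop UGI credentials. A hypothetical test-style sketch of the non-YARN contract (this is not the ConnectionManagerSuite added by the commit, and it assumes a non-YARN environment with no conflicting spark.* system properties):

```scala
package org.apache.spark

import org.scalatest.FunSuite

// Hypothetical sketch of the secret-handling behavior described above.
class SecurityManagerSecretSketch extends FunSuite {

  test("authentication off: no secret is required") {
    val conf = new SparkConf().set("spark.authenticate", "false")
    val sm = new SecurityManager(conf)
    assert(!sm.isAuthenticationEnabled)
  }

  test("authentication on without a secret fails off YARN") {
    val conf = new SparkConf().set("spark.authenticate", "true")
    intercept[Exception] {
      new SecurityManager(conf) // generateSecretKey throws: no spark.authenticate.secret set
    }
  }

  test("authentication on with spark.authenticate.secret succeeds") {
    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "placeholder-secret") // any non-empty value
    val sm = new SecurityManager(conf)
    assert(sm.isAuthenticationEnabled)
  }
}
```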
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -640,7 +640,7 @@ class SparkContext(
addedFiles(key) = System.currentTimeMillis

// Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf)
Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf, env.securityManager)

logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
6 changes: 3 additions & 3 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -124,7 +124,7 @@ object SparkEnv extends Logging {
isDriver: Boolean,
isLocal: Boolean): SparkEnv = {

val securityManager = new SecurityManager()
val securityManager = new SecurityManager(conf)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf,
securityManager = securityManager)

@@ -197,9 +197,9 @@ object SparkEnv extends Logging {
conf.set("spark.fileserver.uri", httpFileServer.serverUri)

val metricsSystem = if (isDriver) {
MetricsSystem.createMetricsSystem("driver", conf)
MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
MetricsSystem.createMetricsSystem("executor", conf)
MetricsSystem.createMetricsSystem("executor", conf, securityManager)
}
metricsSystem.start()

42 changes: 25 additions & 17 deletions core/src/main/scala/org/apache/spark/SparkSaslClient.scala
@@ -52,22 +52,26 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg
* @return response to challenge if needed
*/
def firstToken(): Array[Byte] = {
val saslToken: Array[Byte] =
if (saslClient.hasInitialResponse()) {
logDebug("has initial response")
saslClient.evaluateChallenge(new Array[Byte](0))
} else {
new Array[Byte](0)
}
saslToken
synchronized {
val saslToken: Array[Byte] =
if (saslClient != null && saslClient.hasInitialResponse()) {
logDebug("has initial response")
saslClient.evaluateChallenge(new Array[Byte](0))
} else {
new Array[Byte](0)
}
saslToken
}
}

/**
* Determines whether the authentication exchange has completed.
* @return true is complete, otherwise false
*/
def isComplete(): Boolean = {
saslClient.isComplete()
synchronized {
if (saslClient != null) saslClient.isComplete() else false
}
}

/**
@@ -76,21 +80,25 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg
* @return client's response SASL token
*/
def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
saslClient.evaluateChallenge(saslTokenMessage)
synchronized {
if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
}
}

/**
* Disposes of any system resources or security-sensitive information the
* SaslClient might be using.
*/
def dispose() {
if (saslClient != null) {
try {
saslClient.dispose()
} catch {
case e: SaslException => // ignored
} finally {
saslClient = null
synchronized {
if (saslClient != null) {
try {
saslClient.dispose()
} catch {
case e: SaslException => // ignored
} finally {
saslClient = null
}
}
}
}
24 changes: 15 additions & 9 deletions core/src/main/scala/org/apache/spark/SparkSaslServer.scala
@@ -47,7 +47,9 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi
* @return true is complete, otherwise false
*/
def isComplete(): Boolean = {
saslServer.isComplete()
synchronized {
if (saslServer != null) saslServer.isComplete() else false
}
}

/**
@@ -56,21 +58,25 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi
* @return response to send back to the server.
*/
def response(token: Array[Byte]): Array[Byte] = {
saslServer.evaluateResponse(token)
synchronized {
if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
}
}

/**
* Disposes of any system resources or security-sensitive information the
* SaslServer might be using.
*/
def dispose() {
if (saslServer != null) {
try {
saslServer.dispose()
} catch {
case e: SaslException => // ignore
} finally {
saslServer = null
synchronized {
if (saslServer != null) {
try {
saslServer.dispose()
} catch {
case e: SaslException => // ignore
} finally {
saslServer = null
}
}
}
}
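Both SASL wrappers get the same treatment: every access to the underlying saslClient/saslServer is synchronized and null-checked, so a dispose() racing with isComplete() or a challenge/response call can no longer hit a null or already-disposed object. A generic sketch of that guard pattern (illustration only, not Spark code):

```scala
// Illustration only: the synchronized, null-guarded access pattern used by
// SparkSaslClient and SparkSaslServer above, expressed over an arbitrary resource.
class GuardedResource[T <: AnyRef](private var resource: T) {

  // Run f against the resource if it is still present; otherwise return the default.
  def withResource[R](default: R)(f: T => R): R = synchronized {
    if (resource != null) f(resource) else default
  }

  // Dispose at most once; the field is cleared even if cleanup throws, so later
  // readers fall back to their defaults instead of touching a dead object.
  def dispose(cleanup: T => Unit): Unit = synchronized {
    if (resource != null) {
      try cleanup(resource)
      finally resource = null.asInstanceOf[T]
    }
  }
}
```

In the actual wrappers the read paths fall back to false or an empty byte array once the SASL object has been disposed.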
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -141,7 +141,7 @@ object Client {
// TODO: See if we can initialize akka so return messages are sent back using the same TCP
// flow. Else, this (sadly) requires the DriverClient be routable from the Master.
val (actorSystem, _) = AkkaUtils.createActorSystem(
"driverClient", Utils.localHostName(), 0, false, conf, new SecurityManager)
"driverClient", Utils.localHostName(), 0, false, conf, new SecurityManager(conf))

actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf))

@@ -45,8 +45,9 @@ private[spark] object TestClient {

def main(args: Array[String]) {
val url = args(0)
val conf = new SparkConf
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0,
conf = new SparkConf, securityManager = new SecurityManager())
conf = conf, securityManager = new SecurityManager(conf))
val desc = new ApplicationDescription(
"TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()),
Some("dummy-spark-home"), "ignored")
14 changes: 9 additions & 5 deletions core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -40,7 +40,8 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{AkkaUtils, Utils}
import org.apache.spark.deploy.master.DriverState.DriverState

private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
private[spark] class Master(host: String, port: Int, webUiPort: Int,
val securityMgr: SecurityManager) extends Actor with Logging {
import context.dispatcher // to use Akka's scheduler.schedule()

val conf = new SparkConf
@@ -71,8 +72,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act

Utils.checkHost(host, "Expected hostname")

val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf)
val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf)
val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr)
val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf,
securityMgr)
val masterSource = new MasterSource(this)

val webUi = new MasterWebUI(this, webUiPort)
@@ -712,9 +714,11 @@ private[spark] object Master {
def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf)
: (ActorSystem, Int, Int) =
{
val securityMgr = new SecurityManager(conf)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf,
securityManager = new SecurityManager)
val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName)
securityManager = securityMgr)
val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort,
securityMgr), actorName)
val timeout = AkkaUtils.askTimeout(conf)
val respFuture = actor.ask(RequestWebUIPort)(timeout)
val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse]
@@ -46,7 +46,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {

def start() {
try {
val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers)
val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, master.conf)
server = Some(srv)
boundPort = Some(bPort)
logInfo("Started Master web UI at http://%s:%d".format(host, boundPort.get))
@@ -63,10 +63,14 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
val handlers = metricsHandlers ++ Seq[ServletContextHandler](
createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static/*"),
createServletHandler("/app/json",
(request: HttpServletRequest) => applicationPage.renderJson(request)),
createServletHandler("/app", (request: HttpServletRequest) => applicationPage.render(request)),
createServletHandler("/json", (request: HttpServletRequest) => indexPage.renderJson(request)),
createServletHandler("*", (request: HttpServletRequest) => indexPage.render(request))
createServlet((request: HttpServletRequest) => applicationPage.renderJson(request),
master.securityMgr)),
createServletHandler("/app", createServlet((request: HttpServletRequest) => applicationPage
.render(request), master.securityMgr)),
createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage
.renderJson(request), master.securityMgr)),
createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render
(request), master.securityMgr))
)

def stop() {
@@ -29,8 +29,9 @@ object DriverWrapper {
def main(args: Array[String]) {
args.toList match {
case workerUrl :: mainClass :: extraArgs =>
val conf = new SparkConf()
val (actorSystem, _) = AkkaUtils.createActorSystem("Driver",
Utils.localHostName(), 0, false, new SparkConf(), new SecurityManager())
Utils.localHostName(), 0, false, conf, new SecurityManager(conf))
actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher")

// Delegate to supplied main class
10 changes: 6 additions & 4 deletions core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -49,7 +49,8 @@ private[spark] class Worker(
actorSystemName: String,
actorName: String,
workDirPath: String = null,
val conf: SparkConf)
val conf: SparkConf,
val securityMgr: SecurityManager)
extends Actor with Logging {
import context.dispatcher

@@ -92,7 +93,7 @@ private[spark] class Worker(
var coresUsed = 0
var memoryUsed = 0

val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf)
val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr)
val workerSource = new WorkerSource(this)

def coresFree: Int = cores - coresUsed
@@ -348,10 +349,11 @@ private[spark] object Worker {
val conf = new SparkConf
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val actorName = "Worker"
val securityMgr = new SecurityManager(conf)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port,
conf = conf, securityManager = new SecurityManager)
conf = conf, securityManager = securityMgr)
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
masterUrls, systemName, actorName, workDir, conf), name = actorName)
masterUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName)
(actorSystem, boundPort)
}

@@ -34,7 +34,7 @@ import org.apache.spark.util.{AkkaUtils, Utils}
*/
private[spark]
class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None)
extends Logging {
extends Logging {
val timeout = AkkaUtils.askTimeout(worker.conf)
val host = Utils.localHostName()
val port = requestedPort.getOrElse(
@@ -49,15 +49,19 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I

val handlers = metricsHandlers ++ Seq[ServletContextHandler](
createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static/*"),
createServletHandler("/log", (request: HttpServletRequest) => log(request)),
createServletHandler("/logPage", (request: HttpServletRequest) => logPage(request)),
createServletHandler("/json", (request: HttpServletRequest) => indexPage.renderJson(request)),
createServletHandler("*", (request: HttpServletRequest) => indexPage.render(request))
createServletHandler("/log", createServlet((request: HttpServletRequest) => log(request),
worker.securityMgr)),
createServletHandler("/logPage", createServlet((request: HttpServletRequest) => logPage
(request), worker.securityMgr)),
createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage
.renderJson(request), worker.securityMgr)),
createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render
(request), worker.securityMgr))
)

def start() {
try {
val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers)
val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, worker.conf)
server = Some(srv)
boundPort = Some(bPort)
logInfo("Started Worker web UI at http://%s:%d".format(host, bPort))