Skip to content

Commit

Permalink
[SPARK-3632] ConnectionManager can run out of receive threads with au…
Browse files Browse the repository at this point in the history
…thentication on

If you turn authentication on and you are using a lot of executors. There is a chance that all the of the threads in the handleMessageExecutor could be waiting to send a message because they are blocked waiting on authentication to happen. This can cause a temporary deadlock until the connection times out.

To fix it, I got rid of the wait/notify and use a single outbox but only send security messages from it until authentication has completed.

Author: Thomas Graves <tgraves@apache.org>

Closes apache#2484 from tgravescs/cm_threads_auth and squashes the following commits:

a0a961d [Thomas Graves] give it a type
b6bc80b [Thomas Graves] Rework comments
d6d4175 [Thomas Graves] update from comments
081b765 [Thomas Graves] cleanup
4d7f8f5 [Thomas Graves] Change to not use wait/notify while waiting for authentication
  • Loading branch information
tgravescs authored and rxin committed Oct 2, 2014
1 parent 5db78e6 commit 127e97b
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 79 deletions.
5 changes: 2 additions & 3 deletions core/src/main/scala/org/apache/spark/SecurityManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,9 @@ import org.apache.spark.deploy.SparkHadoopUtil
* and a Server, so for a particular connection is has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
* match up incoming messages with connections waiting for authentication.
* If its acting as a client and trying to send a message to another ConnectionManager,
* it blocks the thread calling sendMessage until the SASL negotiation has occurred.
* The ConnectionManager tracks all the sendingConnections using the ConnectionId
* and waits for the response from the server and does the handshake.
* and waits for the response from the server and does the handshake before sending
* the real message.
*
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
Expand Down
65 changes: 44 additions & 21 deletions core/src/main/scala/org/apache/spark/network/nio/Connection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,27 @@ package org.apache.spark.network.nio
import java.net._
import java.nio._
import java.nio.channels._
import java.util.LinkedList

import org.apache.spark._

import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
import scala.collection.mutable.{ArrayBuffer, HashMap}

private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,
val securityMgr: SecurityManager)
extends Logging {

var sparkSaslServer: SparkSaslServer = null
var sparkSaslClient: SparkSaslClient = null

def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId,
securityMgr_ : SecurityManager) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_)
channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]),
id_, securityMgr_)
}

channel.configureBlocking(false)
Expand All @@ -52,14 +56,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,

val remoteAddress = getRemoteAddress()

/**
* Used to synchronize client requests: client's work-related requests must
* wait until SASL authentication completes.
*/
private val authenticated = new Object()

def getAuthenticated(): Object = authenticated

def isSaslComplete(): Boolean

def resetForceReregister(): Boolean
Expand Down Expand Up @@ -192,22 +188,22 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,

private[nio]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
remoteId_ : ConnectionManagerId, id_ : ConnectionId)
extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
remoteId_ : ConnectionManagerId, id_ : ConnectionId,
securityMgr_ : SecurityManager)
extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) {

def isSaslComplete(): Boolean = {
if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
}

private class Outbox {
val messages = new Queue[Message]()
val messages = new LinkedList[Message]()
val defaultChunkSize = 65536
var nextMessageToBeUsed = 0

def addMessage(message: Message) {
messages.synchronized {
/* messages += message */
messages.enqueue(message)
messages.add(message)
logDebug("Added [" + message + "] to outbox for sending to " +
"[" + getRemoteConnectionManagerId() + "]")
}
Expand All @@ -218,10 +214,27 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
while (!messages.isEmpty) {
/* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
/* val message = messages(nextMessageToBeUsed) */
val message = messages.dequeue()

val message = if (securityMgr.isAuthenticationEnabled() && !isSaslComplete()) {
// only allow sending of security messages until sasl is complete
var pos = 0
var securityMsg: Message = null
while (pos < messages.size() && securityMsg == null) {
if (messages.get(pos).isSecurityNeg) {
securityMsg = messages.remove(pos)
}
pos = pos + 1
}
// didn't find any security messages and auth isn't completed so return
if (securityMsg == null) return None
securityMsg
} else {
messages.removeFirst()
}

val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
messages.enqueue(message)
messages.add(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
logDebug(
Expand Down Expand Up @@ -273,6 +286,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
changeConnectionKeyInterest(DEFAULT_INTEREST)
}

def registerAfterAuth(): Unit = {
outbox.synchronized {
needForceReregister = true
}
if (channel.isConnected) {
registerInterest()
}
}

def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
Expand Down Expand Up @@ -415,8 +437,9 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
private[spark] class ReceivingConnection(
channel_ : SocketChannel,
selector_ : Selector,
id_ : ConnectionId)
extends Connection(channel_, selector_, id_) {
id_ : ConnectionId,
securityMgr_ : SecurityManager)
extends Connection(channel_, selector_, id_, securityMgr_) {

def isSaslComplete(): Boolean = {
if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps

import org.apache.spark._
import org.apache.spark.util.{SystemClock, Utils}
import org.apache.spark.util.Utils


private[nio] class ConnectionManager(
Expand Down Expand Up @@ -65,8 +65,6 @@ private[nio] class ConnectionManager(
private val selector = SelectorProvider.provider.openSelector()
private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)

// default to 30 second timeout waiting for authentication
private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)

private val handleMessageExecutor = new ThreadPoolExecutor(
Expand Down Expand Up @@ -409,7 +407,8 @@ private[nio] class ConnectionManager(
while (newChannel != null) {
try {
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId,
securityManager)
newConnection.onReceive(receiveMessage)
addListeners(newConnection)
addConnection(newConnection)
Expand Down Expand Up @@ -527,9 +526,8 @@ private[nio] class ConnectionManager(
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
waitingConn.getAuthenticated().synchronized {
waitingConn.getAuthenticated().notifyAll()
}
waitingConn.registerAfterAuth()
wakeupSelector()
return
} else {
var replyToken : Array[Byte] = null
Expand All @@ -538,9 +536,8 @@ private[nio] class ConnectionManager(
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
waitingConn.getAuthenticated().synchronized {
waitingConn.getAuthenticated().notifyAll()
}
waitingConn.registerAfterAuth()
wakeupSelector()
return
}
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
Expand Down Expand Up @@ -574,9 +571,11 @@ private[nio] class ConnectionManager(
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
if (connection.isSaslComplete()) {
logDebug("Server sasl completed: " + connection.connectionId)
logDebug("Server sasl completed: " + connection.connectionId +
" for: " + connectionId)
} else {
logDebug("Server sasl not completed: " + connection.connectionId)
logDebug("Server sasl not completed: " + connection.connectionId +
" for: " + connectionId)
}
if (replyToken != null) {
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
Expand Down Expand Up @@ -723,7 +722,8 @@ private[nio] class ConnectionManager(
if (message == null) throw new Exception("Error creating security message")
connectionsAwaitingSasl += ((conn.connectionId, conn))
sendSecurityMessage(connManagerId, message)
logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId +
" to: " + connManagerId)
} catch {
case e: Exception => {
logError("Error getting first response from the SaslClient.", e)
Expand All @@ -744,7 +744,7 @@ private[nio] class ConnectionManager(
val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
newConnectionId)
newConnectionId, securityManager)
logInfo("creating new sending connection for security! " + newConnectionId )
registerRequests.enqueue(newConnection)

Expand All @@ -769,61 +769,23 @@ private[nio] class ConnectionManager(
connectionManagerId.port)
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
newConnectionId)
newConnectionId, securityManager)
logTrace("creating new sending connection: " + newConnectionId)
registerRequests.enqueue(newConnection)

newConnection
}
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
if (authEnabled) {
checkSendAuthFirst(connectionManagerId, connection)
}

message.senderAddress = id.toSocketAddress()
logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
"connectionid: " + connection.connectionId)

if (authEnabled) {
// if we aren't authenticated yet lets block the senders until authentication completes
try {
connection.getAuthenticated().synchronized {
val clock = SystemClock
val startTime = clock.getTime()

while (!connection.isSaslComplete()) {
logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
// have timeout in case remote side never responds
connection.getAuthenticated().wait(500)
if (((clock.getTime() - startTime) >= (authTimeout * 1000))
&& (!connection.isSaslComplete())) {
// took to long to authenticate the connection, something probably went wrong
throw new Exception("Took to long for authentication to " + connectionManagerId +
", waited " + authTimeout + "seconds, failing.")
}
}
}
} catch {
case e: Exception => logError("Exception while waiting for authentication.", e)

// need to tell sender it failed
messageStatuses.synchronized {
val s = messageStatuses.get(message.id)
s match {
case Some(msgStatus) => {
messageStatuses -= message.id
logInfo("Notifying " + msgStatus.connectionManagerId)
msgStatus.markDone(None)
}
case None => {
logError("no messageStatus for failed message id: " + message.id)
}
}
}
}
checkSendAuthFirst(connectionManagerId, connection)
}
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
connection.send(message)

wakeupSelector()
}

Expand Down

0 comments on commit 127e97b

Please sign in to comment.