Skip to content

[SPARK-3632] ConnectionManager can run out of receive threads with authentication on #2484

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions core/src/main/scala/org/apache/spark/SecurityManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,9 @@ import org.apache.spark.deploy.SparkHadoopUtil
* and a Server, so for a particular connection is has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
* match up incoming messages with connections waiting for authentication.
* If its acting as a client and trying to send a message to another ConnectionManager,
* it blocks the thread calling sendMessage until the SASL negotiation has occurred.
* The ConnectionManager tracks all the sendingConnections using the ConnectionId
* and waits for the response from the server and does the handshake.
* and waits for the response from the server and does the handshake before sending
* the real message.
*
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
Expand Down
65 changes: 44 additions & 21 deletions core/src/main/scala/org/apache/spark/network/nio/Connection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,27 @@ package org.apache.spark.network.nio
import java.net._
import java.nio._
import java.nio.channels._
import java.util.LinkedList

import org.apache.spark._

import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
import scala.collection.mutable.{ArrayBuffer, HashMap}

private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,
val securityMgr: SecurityManager)
extends Logging {

var sparkSaslServer: SparkSaslServer = null
var sparkSaslClient: SparkSaslClient = null

def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId,
securityMgr_ : SecurityManager) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_)
channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]),
id_, securityMgr_)
}

channel.configureBlocking(false)
Expand All @@ -52,14 +56,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,

val remoteAddress = getRemoteAddress()

/**
* Used to synchronize client requests: client's work-related requests must
* wait until SASL authentication completes.
*/
private val authenticated = new Object()

def getAuthenticated(): Object = authenticated

def isSaslComplete(): Boolean

def resetForceReregister(): Boolean
Expand Down Expand Up @@ -192,22 +188,22 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,

private[nio]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
remoteId_ : ConnectionManagerId, id_ : ConnectionId)
extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
remoteId_ : ConnectionManagerId, id_ : ConnectionId,
securityMgr_ : SecurityManager)
extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) {

def isSaslComplete(): Boolean = {
if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
}

private class Outbox {
val messages = new Queue[Message]()
val messages = new LinkedList[Message]()
val defaultChunkSize = 65536
var nextMessageToBeUsed = 0

def addMessage(message: Message) {
messages.synchronized {
/* messages += message */
messages.enqueue(message)
messages.add(message)
logDebug("Added [" + message + "] to outbox for sending to " +
"[" + getRemoteConnectionManagerId() + "]")
}
Expand All @@ -218,10 +214,27 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
while (!messages.isEmpty) {
/* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
/* val message = messages(nextMessageToBeUsed) */
val message = messages.dequeue()

val message = if (securityMgr.isAuthenticationEnabled() && !isSaslComplete()) {
// only allow sending of security messages until sasl is complete
var pos = 0
var securityMsg: Message = null
while (pos < messages.size() && securityMsg == null) {
if (messages.get(pos).isSecurityNeg) {
securityMsg = messages.remove(pos)
}
pos = pos + 1
}
// didn't find any security messages and auth isn't completed so return
if (securityMsg == null) return None
securityMsg
} else {
messages.removeFirst()
}

val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
messages.enqueue(message)
messages.add(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
logDebug(
Expand Down Expand Up @@ -273,6 +286,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
changeConnectionKeyInterest(DEFAULT_INTEREST)
}

def registerAfterAuth(): Unit = {
outbox.synchronized {
needForceReregister = true
}
if (channel.isConnected) {
registerInterest()
}
}

def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
Expand Down Expand Up @@ -415,8 +437,9 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
private[spark] class ReceivingConnection(
channel_ : SocketChannel,
selector_ : Selector,
id_ : ConnectionId)
extends Connection(channel_, selector_, id_) {
id_ : ConnectionId,
securityMgr_ : SecurityManager)
extends Connection(channel_, selector_, id_, securityMgr_) {

def isSaslComplete(): Boolean = {
if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps

import org.apache.spark._
import org.apache.spark.util.{SystemClock, Utils}
import org.apache.spark.util.Utils


private[nio] class ConnectionManager(
Expand Down Expand Up @@ -65,8 +65,6 @@ private[nio] class ConnectionManager(
private val selector = SelectorProvider.provider.openSelector()
private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)

// default to 30 second timeout waiting for authentication
private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)

private val handleMessageExecutor = new ThreadPoolExecutor(
Expand Down Expand Up @@ -409,7 +407,8 @@ private[nio] class ConnectionManager(
while (newChannel != null) {
try {
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId,
securityManager)
newConnection.onReceive(receiveMessage)
addListeners(newConnection)
addConnection(newConnection)
Expand Down Expand Up @@ -527,9 +526,8 @@ private[nio] class ConnectionManager(
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
waitingConn.getAuthenticated().synchronized {
waitingConn.getAuthenticated().notifyAll()
}
waitingConn.registerAfterAuth()
wakeupSelector()
return
} else {
var replyToken : Array[Byte] = null
Expand All @@ -538,9 +536,8 @@ private[nio] class ConnectionManager(
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
waitingConn.getAuthenticated().synchronized {
waitingConn.getAuthenticated().notifyAll()
}
waitingConn.registerAfterAuth()
wakeupSelector()
return
}
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
Expand Down Expand Up @@ -574,9 +571,11 @@ private[nio] class ConnectionManager(
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
if (connection.isSaslComplete()) {
logDebug("Server sasl completed: " + connection.connectionId)
logDebug("Server sasl completed: " + connection.connectionId +
" for: " + connectionId)
} else {
logDebug("Server sasl not completed: " + connection.connectionId)
logDebug("Server sasl not completed: " + connection.connectionId +
" for: " + connectionId)
}
if (replyToken != null) {
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
Expand Down Expand Up @@ -723,7 +722,8 @@ private[nio] class ConnectionManager(
if (message == null) throw new Exception("Error creating security message")
connectionsAwaitingSasl += ((conn.connectionId, conn))
sendSecurityMessage(connManagerId, message)
logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId +
" to: " + connManagerId)
} catch {
case e: Exception => {
logError("Error getting first response from the SaslClient.", e)
Expand All @@ -744,7 +744,7 @@ private[nio] class ConnectionManager(
val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
newConnectionId)
newConnectionId, securityManager)
logInfo("creating new sending connection for security! " + newConnectionId )
registerRequests.enqueue(newConnection)

Expand All @@ -769,61 +769,23 @@ private[nio] class ConnectionManager(
connectionManagerId.port)
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
newConnectionId)
newConnectionId, securityManager)
logTrace("creating new sending connection: " + newConnectionId)
registerRequests.enqueue(newConnection)

newConnection
}
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
if (authEnabled) {
checkSendAuthFirst(connectionManagerId, connection)
}

message.senderAddress = id.toSocketAddress()
logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
"connectionid: " + connection.connectionId)

if (authEnabled) {
// if we aren't authenticated yet lets block the senders until authentication completes
try {
connection.getAuthenticated().synchronized {
val clock = SystemClock
val startTime = clock.getTime()

while (!connection.isSaslComplete()) {
logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
// have timeout in case remote side never responds
connection.getAuthenticated().wait(500)
if (((clock.getTime() - startTime) >= (authTimeout * 1000))
&& (!connection.isSaslComplete())) {
// took to long to authenticate the connection, something probably went wrong
throw new Exception("Took to long for authentication to " + connectionManagerId +
", waited " + authTimeout + "seconds, failing.")
}
}
}
} catch {
case e: Exception => logError("Exception while waiting for authentication.", e)

// need to tell sender it failed
messageStatuses.synchronized {
val s = messageStatuses.get(message.id)
s match {
case Some(msgStatus) => {
messageStatuses -= message.id
logInfo("Notifying " + msgStatus.connectionManagerId)
msgStatus.markDone(None)
}
case None => {
logError("no messageStatus for failed message id: " + message.id)
}
}
}
}
checkSendAuthFirst(connectionManagerId, connection)
}
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
connection.send(message)

wakeupSelector()
}

Expand Down