
Commit f1cd1bb

Clean up @sarutak's PR #1490 for [SPARK-2583]: ConnectionManager error reporting
Use Futures to signal failures, rather than exposing empty messages to user code.
1 parent: 7399c6b
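
In short: sendMessageReliably now returns a Future[Message] that fails with an IOException when a message is never ACK'd, or when the ACK reports a remote error, instead of completing with an empty Option. A minimal caller-side sketch of the new pattern follows; the manager, peer, and msg values (and the sendAndReport wrapper) are illustrative placeholders, not part of this commit.

import java.io.IOException
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}

// Sketch only: `manager`, `peer`, and `msg` stand in for an existing
// ConnectionManager, ConnectionManagerId, and Message from this codebase.
def sendAndReport(manager: ConnectionManager, peer: ConnectionManagerId, msg: Message): Unit = {
  manager.sendMessageReliably(peer, msg).onComplete {
    case Success(ack) =>
      // A successful future now always carries a real, error-free ack Message.
      println("ACK received for message " + msg.id)
    case Failure(e: IOException) =>
      // No ACK, or an ACK flagged with a remote error, surfaces here as an exception.
      println("Send failed: " + e.getMessage)
    case Failure(e) =>
      println("Send failed unexpectedly: " + e)
  }
}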

6 files changed: 87 additions & 76 deletions

core/src/main/scala/org/apache/spark/network/ConnectionManager.scala

Lines changed: 56 additions & 38 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network
 
+import java.io.IOException
 import java.nio._
 import java.nio.channels._
 import java.nio.channels.spi._
@@ -41,16 +42,26 @@ import org.apache.spark.util.{SystemClock, Utils}
 private[spark] class ConnectionManager(port: Int, conf: SparkConf,
     securityManager: SecurityManager) extends Logging {
 
+  /**
+   * Used by sendMessageReliably to track messages being sent.
+   * @param message the message that was sent
+   * @param connectionManagerId the connection manager that sent this message
+   * @param completionHandler callback that's invoked when the send has completed or failed
+   */
   class MessageStatus(
       val message: Message,
       val connectionManagerId: ConnectionManagerId,
       completionHandler: MessageStatus => Unit) {
 
+    /** This is non-None if message has been ack'd */
     var ackMessage: Option[Message] = None
-    var attempted = false
-    var acked = false
 
-    def markDone() { completionHandler(this) }
+    def markDone(ackMessage: Option[Message]) {
+      this.synchronized {
+        this.ackMessage = ackMessage
+        completionHandler(this)
+      }
+    }
   }
 
   private val selector = SelectorProvider.provider.openSelector()
@@ -434,11 +445,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
      messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
        .foreach(status => {
          logInfo("Notifying " + status)
-         status.synchronized {
-           status.attempted = true
-           status.acked = false
-           status.markDone()
-         }
+         status.markDone(None)
        })
 
      messageStatuses.retain((i, status) => {
@@ -467,11 +474,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
      for (s <- messageStatuses.values
           if s.connectionManagerId == sendingConnectionManagerId) {
        logInfo("Notifying " + s)
-       s.synchronized {
-         s.attempted = true
-         s.acked = false
-         s.markDone()
-       }
+       s.markDone(None)
      }
 
      messageStatuses.retain((i, status) => {
@@ -539,13 +542,13 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
        val securityMsgResp = SecurityMessage.fromResponse(replyToken,
          securityMsg.getConnectionId.toString)
        val message = securityMsgResp.toBufferMessage
-       if (message == null) throw new Exception("Error creating security message")
+       if (message == null) throw new IOException("Error creating security message")
        sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
      } catch {
        case e: Exception => {
          logError("Error handling sasl client authentication", e)
          waitingConn.close()
-         throw new Exception("Error evaluating sasl response: " + e)
+         throw new IOException("Error evaluating sasl response: " + e)
        }
      }
    }
@@ -653,12 +656,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
            }
          }
        }
-       sentMessageStatus.synchronized {
-         sentMessageStatus.ackMessage = Some(message)
-         sentMessageStatus.attempted = true
-         sentMessageStatus.acked = true
-         sentMessageStatus.markDone()
-       }
+       sentMessageStatus.markDone(Some(message))
      } else {
        var ackMessage : Option[Message] = None
        try {
@@ -681,7 +679,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
          }
        } catch {
          case e: Exception => {
-           logError(s"Exception was thrown during processing message", e)
+           logError(s"Exception was thrown while processing message", e)
            val m = Message.createBufferMessage(bufferMessage.id)
            m.hasError = true
            ackMessage = Some(m)
@@ -802,11 +800,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
          case Some(msgStatus) => {
            messageStatuses -= message.id
            logInfo("Notifying " + msgStatus.connectionManagerId)
-           msgStatus.synchronized {
-             msgStatus.attempted = true
-             msgStatus.acked = false
-             msgStatus.markDone()
-           }
+           msgStatus.markDone(None)
          }
          case None => {
            logError("no messageStatus for failed message id: " + message.id)
@@ -825,11 +819,28 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
    selector.wakeup()
  }
 
+  /**
+   * Send a message and block until an acknowldgment is received or an error occurs.
+   * @param connectionManagerId the message's destination
+   * @param message the message being sent
+   * @return a Future that either returns the acknowledgment message or captures an exception.
+   */
  def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
-      : Future[Option[Message]] = {
-    val promise = Promise[Option[Message]]
-    val status = new MessageStatus(
-      message, connectionManagerId, s => promise.success(s.ackMessage))
+      : Future[Message] = {
+    val promise = Promise[Message]()
+    val status = new MessageStatus(message, connectionManagerId, s => {
+      s.ackMessage match {
+        case None => // Indicates a failure where we either never sent or never got ACK'd
+          promise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
+        case Some(ackMessage) =>
+          if (ackMessage.hasError) {
+            promise.failure(
+              new IOException("sendMessageReliably failed with ACK that signalled a remote error"))
+          } else {
+            promise.success(ackMessage)
+          }
+      }
+    })
    messageStatuses.synchronized {
      messageStatuses += ((message.id, status))
    }
@@ -838,7 +849,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
  }
 
  def sendMessageReliablySync(connectionManagerId: ConnectionManagerId,
-      message: Message): Option[Message] = {
+      message: Message): Message = {
    Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf)
  }
 
@@ -864,6 +875,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
 
 
 private[spark] object ConnectionManager {
+  import ExecutionContext.Implicits.global
 
  def main(args: Array[String]) {
    val conf = new SparkConf
@@ -919,8 +931,10 @@ private[spark] object ConnectionManager {
      val bufferMessage = Message.createBufferMessage(buffer.duplicate)
      manager.sendMessageReliably(manager.id, bufferMessage)
    }).foreach(f => {
-      val g = Await.result(f, 1 second)
-      if (!g.isDefined) println("Failed")
+      f.onFailure {
+        case e => println("Failed due to " + e)
+      }
+      Await.ready(f, 1 second)
    })
    val finishTime = System.currentTimeMillis
 
@@ -954,8 +968,10 @@ private[spark] object ConnectionManager {
      val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
      manager.sendMessageReliably(manager.id, bufferMessage)
    }).foreach(f => {
-      val g = Await.result(f, 1 second)
-      if (!g.isDefined) println("Failed")
+      f.onFailure {
+        case e => println("Failed due to " + e)
+      }
+      Await.ready(f, 1 second)
    })
    val finishTime = System.currentTimeMillis
 
@@ -984,8 +1000,10 @@ private[spark] object ConnectionManager {
      val bufferMessage = Message.createBufferMessage(buffer.duplicate)
      manager.sendMessageReliably(manager.id, bufferMessage)
    }).foreach(f => {
-      val g = Await.result(f, 1 second)
-      if (!g.isDefined) println("Failed")
+      f.onFailure {
+        case e => println("Failed due to " + e)
+      }
+      Await.ready(f, 1 second)
    })
    val finishTime = System.currentTimeMillis
    Thread.sleep(1000)
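
A note on the blocking variant: sendMessageReliablySync now returns the acknowledgment Message directly, and Await.result rethrows the IOException if the underlying future failed, so blocking callers check success with scala.util.Try instead of Option.isDefined. That is the pattern adopted by the remaining files in this commit; a hedged sketch, again with manager, peer, and msg as placeholders:

import scala.util.{Failure, Success, Try}

// Placeholders: `manager`, `peer`, and `msg` stand in for a ConnectionManager,
// a ConnectionManagerId, and a Message that the caller already has.
val reply: Try[Message] = Try(manager.sendMessageReliablySync(peer, msg))
reply match {
  case Success(ack) => println("Got reply: " + ack)
  case Failure(e)   => println("Request failed: " + e)
}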

core/src/main/scala/org/apache/spark/network/SenderTest.scala

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@
 package org.apache.spark.network
 
 import java.nio.ByteBuffer
+import scala.util.Try
 import org.apache.spark.{SecurityManager, SparkConf}
 
 private[spark] object SenderTest {
@@ -51,7 +52,7 @@ private[spark] object SenderTest {
      val dataMessage = Message.createBufferMessage(buffer.duplicate)
      val startTime = System.currentTimeMillis
      /* println("Started timer at " + startTime) */
-      val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage)
+      val responseStr = Try(manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage))
        .map { response =>
          val buffer = response.asInstanceOf[BufferMessage].buffers(0)
          new String(buffer.array, "utf-8")

core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala

Lines changed: 15 additions & 21 deletions
@@ -22,6 +22,7 @@ import java.util.concurrent.LinkedBlockingQueue
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashSet
 import scala.collection.mutable.Queue
+import scala.util.{Failure, Success}
 
 import io.netty.buffer.ByteBuf
 
@@ -118,31 +119,24 @@ object BlockFetcherIterator {
      bytesInFlight += req.size
      val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
      val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
-      future.onSuccess {
-        case Some(message) => {
+      future.onComplete {
+        case Success(message) => {
          val bufferMessage = message.asInstanceOf[BufferMessage]
-          if (bufferMessage.hasError) {
-            logError("Could not get block(s) from " + cmId)
-            for ((blockId, size) <- req.blocks) {
-              results.put(new FetchResult(blockId, -1, null))
-            }
-          } else {
-            val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
-            for (blockMessage <- blockMessageArray) {
-              if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
-                throw new SparkException(
-                  "Unexpected message " + blockMessage.getType + " received from " + cmId)
-              }
-              val blockId = blockMessage.getId
-              val networkSize = blockMessage.getData.limit()
-              results.put(new FetchResult(blockId, sizeMap(blockId),
-                () => dataDeserialize(blockId, blockMessage.getData, serializer)))
-              _remoteBytesRead += networkSize
-              logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+          val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+          for (blockMessage <- blockMessageArray) {
+            if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+              throw new SparkException(
+                "Unexpected message " + blockMessage.getType + " received from " + cmId)
            }
+            val blockId = blockMessage.getId
+            val networkSize = blockMessage.getData.limit()
+            results.put(new FetchResult(blockId, sizeMap(blockId),
+              () => dataDeserialize(blockId, blockMessage.getData, serializer)))
+            _remoteBytesRead += networkSize
+            logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
          }
        }
-        case None => {
+        case Failure(exception) => {
          logError("Could not get block(s) from " + cmId)
          for ((blockId, size) <- req.blocks) {
            results.put(new FetchResult(blockId, -1, null))

core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala

Lines changed: 9 additions & 7 deletions
@@ -23,6 +23,8 @@ import org.apache.spark.Logging
 import org.apache.spark.network._
 import org.apache.spark.util.Utils
 
+import scala.util.{Failure, Success, Try}
+
 /**
  * A network interface for BlockManager. Each slave should have one
  * BlockManagerWorker.
@@ -115,28 +117,28 @@ private[spark] object BlockManagerWorker extends Logging {
    val connectionManager = blockManager.connectionManager
    val blockMessage = BlockMessage.fromPutBlock(msg)
    val blockMessageArray = new BlockMessageArray(blockMessage)
-    val resultMessage = connectionManager.sendMessageReliablySync(
-      toConnManagerId, blockMessageArray.toBufferMessage)
-    resultMessage.isDefined
+    val resultMessage = Try(connectionManager.sendMessageReliablySync(
+      toConnManagerId, blockMessageArray.toBufferMessage))
+    resultMessage.isSuccess
  }
 
  def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
    val blockManager = blockManagerWorker.blockManager
    val connectionManager = blockManager.connectionManager
    val blockMessage = BlockMessage.fromGetBlock(msg)
    val blockMessageArray = new BlockMessageArray(blockMessage)
-    val responseMessage = connectionManager.sendMessageReliablySync(
-      toConnManagerId, blockMessageArray.toBufferMessage)
+    val responseMessage = Try(connectionManager.sendMessageReliablySync(
+      toConnManagerId, blockMessageArray.toBufferMessage))
    responseMessage match {
-      case Some(message) => {
+      case Success(message) => {
        val bufferMessage = message.asInstanceOf[BufferMessage]
        logDebug("Response message received " + bufferMessage)
        BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => {
          logDebug("Found " + blockMessage)
          return blockMessage.getData
        })
      }
-      case None => logDebug("No response message received")
+      case Failure(exception) => logDebug("No response message received")
    }
    null
  }

core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala

Lines changed: 3 additions & 4 deletions
@@ -25,6 +25,7 @@ import org.scalatest.FunSuite
 import scala.concurrent.{Await, TimeoutException}
 import scala.concurrent.duration._
 import scala.language.postfixOps
+import scala.util.Try
 
 /**
  * Test the ConnectionManager with various security settings.
@@ -209,7 +210,6 @@ class ConnectionManagerSuite extends FunSuite {
    }).foreach(f => {
      try {
        val g = Await.result(f, 1 second)
-        if (!g.isDefined) assert(false) else assert(true)
      } catch {
        case e: Exception => {
          assert(false)
@@ -240,9 +240,8 @@ class ConnectionManagerSuite extends FunSuite {
 
    val future = manager.sendMessageReliably(managerServer.id, bufferMessage)
 
-    val message = Await.result(future, 1 second)
-    assert(message.isDefined)
-    assert(message.get.hasError)
+    val message = Try(Await.result(future, 1 second))
+    assert(message.isFailure)
 
    manager.stop()
    managerServer.stop()
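
The failed-send assertion above could equivalently be written with ScalaTest's intercept, since Await.result rethrows the IOException that failed the future. A sketch, assuming the same imports and the future value from the test above; it is not part of the commit:

// Await.result on a failed Future rethrows its exception, so the test can
// assert on the exception type directly.
intercept[java.io.IOException] {
  Await.result(future, 1 second)
}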

core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala

Lines changed: 2 additions & 5 deletions
@@ -148,8 +148,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
    val f = future {
      val message = Message.createBufferMessage(0)
      message.hasError = true
-      val someMessage = Some(message)
-      someMessage
+      message
    }
    when(connManager.sendMessageReliably(any(),
      any())).thenReturn(f)
@@ -204,10 +203,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
    buffer.flip()
    arrayBuffer += buffer
 
-    val someMessage = Some(Message.createBufferMessage(arrayBuffer))
-
    val f = future {
-      someMessage
+      Message.createBufferMessage(arrayBuffer)
    }
    when(connManager.sendMessageReliably(any(),
      any())).thenReturn(f)
