Skip to content

Commit 4dc449a

Browse files
committed
Merge remote-tracking branch 'upstream/master' into dt-robustness
2 parents 7a61f7b + 4201d27 commit 4dc449a

File tree

13 files changed

+407
-94
lines changed

13 files changed

+407
-94
lines changed

core/src/main/scala/org/apache/spark/network/BufferMessage.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
4848
val security = if (isSecurityNeg) 1 else 0
4949
if (size == 0 && !gotChunkForSendingOnce) {
5050
val newChunk = new MessageChunk(
51-
new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
51+
new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null)
5252
gotChunkForSendingOnce = true
5353
return Some(newChunk)
5454
}
@@ -66,7 +66,8 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
6666
}
6767
buffer.position(buffer.position + newBuffer.remaining)
6868
val newChunk = new MessageChunk(new MessageChunkHeader(
69-
typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
69+
typ, id, size, newBuffer.remaining, ackId,
70+
hasError, security, senderAddress), newBuffer)
7071
gotChunkForSendingOnce = true
7172
return Some(newChunk)
7273
}
@@ -88,7 +89,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
8889
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
8990
buffer.position(buffer.position + newBuffer.remaining)
9091
val newChunk = new MessageChunk(new MessageChunkHeader(
91-
typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
92+
typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer)
9293
return Some(newChunk)
9394
}
9495
None

core/src/main/scala/org/apache/spark/network/ConnectionManager.scala

Lines changed: 83 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.network
1919

20+
import java.io.IOException
2021
import java.nio._
2122
import java.nio.channels._
2223
import java.nio.channels.spi._
@@ -45,16 +46,26 @@ private[spark] class ConnectionManager(
4546
name: String = "Connection manager")
4647
extends Logging {
4748

49+
/**
50+
* Used by sendMessageReliably to track messages being sent.
51+
* @param message the message that was sent
52+
* @param connectionManagerId the connection manager that sent this message
53+
* @param completionHandler callback that's invoked when the send has completed or failed
54+
*/
4855
class MessageStatus(
4956
val message: Message,
5057
val connectionManagerId: ConnectionManagerId,
5158
completionHandler: MessageStatus => Unit) {
5259

60+
/** This is non-None if message has been ack'd */
5361
var ackMessage: Option[Message] = None
54-
var attempted = false
55-
var acked = false
5662

57-
def markDone() { completionHandler(this) }
63+
def markDone(ackMessage: Option[Message]) {
64+
this.synchronized {
65+
this.ackMessage = ackMessage
66+
completionHandler(this)
67+
}
68+
}
5869
}
5970

6071
private val selector = SelectorProvider.provider.openSelector()
@@ -442,11 +453,7 @@ private[spark] class ConnectionManager(
442453
messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
443454
.foreach(status => {
444455
logInfo("Notifying " + status)
445-
status.synchronized {
446-
status.attempted = true
447-
status.acked = false
448-
status.markDone()
449-
}
456+
status.markDone(None)
450457
})
451458

452459
messageStatuses.retain((i, status) => {
@@ -475,11 +482,7 @@ private[spark] class ConnectionManager(
475482
for (s <- messageStatuses.values
476483
if s.connectionManagerId == sendingConnectionManagerId) {
477484
logInfo("Notifying " + s)
478-
s.synchronized {
479-
s.attempted = true
480-
s.acked = false
481-
s.markDone()
482-
}
485+
s.markDone(None)
483486
}
484487

485488
messageStatuses.retain((i, status) => {
@@ -547,13 +550,13 @@ private[spark] class ConnectionManager(
547550
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
548551
securityMsg.getConnectionId.toString)
549552
val message = securityMsgResp.toBufferMessage
550-
if (message == null) throw new Exception("Error creating security message")
553+
if (message == null) throw new IOException("Error creating security message")
551554
sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
552555
} catch {
553556
case e: Exception => {
554557
logError("Error handling sasl client authentication", e)
555558
waitingConn.close()
556-
throw new Exception("Error evaluating sasl response: " + e)
559+
throw new IOException("Error evaluating sasl response: ", e)
557560
}
558561
}
559562
}
@@ -661,34 +664,39 @@ private[spark] class ConnectionManager(
661664
}
662665
}
663666
}
664-
sentMessageStatus.synchronized {
665-
sentMessageStatus.ackMessage = Some(message)
666-
sentMessageStatus.attempted = true
667-
sentMessageStatus.acked = true
668-
sentMessageStatus.markDone()
669-
}
667+
sentMessageStatus.markDone(Some(message))
670668
} else {
671-
val ackMessage = if (onReceiveCallback != null) {
672-
logDebug("Calling back")
673-
onReceiveCallback(bufferMessage, connectionManagerId)
674-
} else {
675-
logDebug("Not calling back as callback is null")
676-
None
677-
}
669+
var ackMessage : Option[Message] = None
670+
try {
671+
ackMessage = if (onReceiveCallback != null) {
672+
logDebug("Calling back")
673+
onReceiveCallback(bufferMessage, connectionManagerId)
674+
} else {
675+
logDebug("Not calling back as callback is null")
676+
None
677+
}
678678

679-
if (ackMessage.isDefined) {
680-
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
681-
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
682-
+ ackMessage.get.getClass)
683-
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
684-
logDebug("Response to " + bufferMessage + " does not have ack id set")
685-
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
679+
if (ackMessage.isDefined) {
680+
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
681+
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
682+
+ ackMessage.get.getClass)
683+
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
684+
logDebug("Response to " + bufferMessage + " does not have ack id set")
685+
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
686+
}
687+
}
688+
} catch {
689+
case e: Exception => {
690+
logError(s"Exception was thrown while processing message", e)
691+
val m = Message.createBufferMessage(bufferMessage.id)
692+
m.hasError = true
693+
ackMessage = Some(m)
686694
}
695+
} finally {
696+
sendMessage(connectionManagerId, ackMessage.getOrElse {
697+
Message.createBufferMessage(bufferMessage.id)
698+
})
687699
}
688-
689-
sendMessage(connectionManagerId, ackMessage.getOrElse {
690-
Message.createBufferMessage(bufferMessage.id)
691-
})
692700
}
693701
}
694702
case _ => throw new Exception("Unknown type message received")
@@ -800,11 +808,7 @@ private[spark] class ConnectionManager(
800808
case Some(msgStatus) => {
801809
messageStatuses -= message.id
802810
logInfo("Notifying " + msgStatus.connectionManagerId)
803-
msgStatus.synchronized {
804-
msgStatus.attempted = true
805-
msgStatus.acked = false
806-
msgStatus.markDone()
807-
}
811+
msgStatus.markDone(None)
808812
}
809813
case None => {
810814
logError("no messageStatus for failed message id: " + message.id)
@@ -823,23 +827,35 @@ private[spark] class ConnectionManager(
823827
selector.wakeup()
824828
}
825829

830+
/**
831+
* Send a message and block until an acknowledgment is received or an error occurs.
832+
* @param connectionManagerId the message's destination
833+
* @param message the message being sent
834+
* @return a Future that either returns the acknowledgment message or captures an exception.
835+
*/
826836
def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
827-
: Future[Option[Message]] = {
828-
val promise = Promise[Option[Message]]
829-
val status = new MessageStatus(
830-
message, connectionManagerId, s => promise.success(s.ackMessage))
837+
: Future[Message] = {
838+
val promise = Promise[Message]()
839+
val status = new MessageStatus(message, connectionManagerId, s => {
840+
s.ackMessage match {
841+
case None => // Indicates a failure where we either never sent or never got ACK'd
842+
promise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
843+
case Some(ackMessage) =>
844+
if (ackMessage.hasError) {
845+
promise.failure(
846+
new IOException("sendMessageReliably failed with ACK that signalled a remote error"))
847+
} else {
848+
promise.success(ackMessage)
849+
}
850+
}
851+
})
831852
messageStatuses.synchronized {
832853
messageStatuses += ((message.id, status))
833854
}
834855
sendMessage(connectionManagerId, message)
835856
promise.future
836857
}
837858

838-
def sendMessageReliablySync(connectionManagerId: ConnectionManagerId,
839-
message: Message): Option[Message] = {
840-
Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf)
841-
}
842-
843859
def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
844860
onReceiveCallback = callback
845861
}
@@ -862,6 +878,7 @@ private[spark] class ConnectionManager(
862878

863879

864880
private[spark] object ConnectionManager {
881+
import ExecutionContext.Implicits.global
865882

866883
def main(args: Array[String]) {
867884
val conf = new SparkConf
@@ -896,7 +913,7 @@ private[spark] object ConnectionManager {
896913

897914
(0 until count).map(i => {
898915
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
899-
manager.sendMessageReliablySync(manager.id, bufferMessage)
916+
Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf)
900917
})
901918
println("--------------------------")
902919
println()
@@ -917,8 +934,10 @@ private[spark] object ConnectionManager {
917934
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
918935
manager.sendMessageReliably(manager.id, bufferMessage)
919936
}).foreach(f => {
920-
val g = Await.result(f, 1 second)
921-
if (!g.isDefined) println("Failed")
937+
f.onFailure {
938+
case e => println("Failed due to " + e)
939+
}
940+
Await.ready(f, 1 second)
922941
})
923942
val finishTime = System.currentTimeMillis
924943

@@ -952,8 +971,10 @@ private[spark] object ConnectionManager {
952971
val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
953972
manager.sendMessageReliably(manager.id, bufferMessage)
954973
}).foreach(f => {
955-
val g = Await.result(f, 1 second)
956-
if (!g.isDefined) println("Failed")
974+
f.onFailure {
975+
case e => println("Failed due to " + e)
976+
}
977+
Await.ready(f, 1 second)
957978
})
958979
val finishTime = System.currentTimeMillis
959980

@@ -982,8 +1003,10 @@ private[spark] object ConnectionManager {
9821003
val bufferMessage = Message.createBufferMessage(buffer.duplicate)
9831004
manager.sendMessageReliably(manager.id, bufferMessage)
9841005
}).foreach(f => {
985-
val g = Await.result(f, 1 second)
986-
if (!g.isDefined) println("Failed")
1006+
f.onFailure {
1007+
case e => println("Failed due to " + e)
1008+
}
1009+
Await.ready(f, 1 second)
9871010
})
9881011
val finishTime = System.currentTimeMillis
9891012
Thread.sleep(1000)

core/src/main/scala/org/apache/spark/network/Message.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
2828
var startTime = -1L
2929
var finishTime = -1L
3030
var isSecurityNeg = false
31+
var hasError = false
3132

3233
def size: Int
3334

@@ -87,6 +88,7 @@ private[spark] object Message {
8788
case BUFFER_MESSAGE => new BufferMessage(header.id,
8889
ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
8990
}
91+
newMessage.hasError = header.hasError
9092
newMessage.senderAddress = header.address
9193
newMessage
9294
}

core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader(
2727
val totalSize: Int,
2828
val chunkSize: Int,
2929
val other: Int,
30+
val hasError: Boolean,
3031
val securityNeg: Int,
3132
val address: InetSocketAddress) {
3233
lazy val buffer = {
@@ -41,6 +42,7 @@ private[spark] class MessageChunkHeader(
4142
putInt(totalSize).
4243
putInt(chunkSize).
4344
putInt(other).
45+
put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]).
4446
putInt(securityNeg).
4547
putInt(ip.size).
4648
put(ip).
@@ -56,7 +58,7 @@ private[spark] class MessageChunkHeader(
5658

5759

5860
private[spark] object MessageChunkHeader {
59-
val HEADER_SIZE = 44
61+
val HEADER_SIZE = 45
6062

6163
def create(buffer: ByteBuffer): MessageChunkHeader = {
6264
if (buffer.remaining != HEADER_SIZE) {
@@ -67,13 +69,14 @@ private[spark] object MessageChunkHeader {
6769
val totalSize = buffer.getInt()
6870
val chunkSize = buffer.getInt()
6971
val other = buffer.getInt()
72+
val hasError = buffer.get() != 0
7073
val securityNeg = buffer.getInt()
7174
val ipSize = buffer.getInt()
7275
val ipBytes = new Array[Byte](ipSize)
7376
buffer.get(ipBytes)
7477
val ip = InetAddress.getByAddress(ipBytes)
7578
val port = buffer.getInt()
76-
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg,
79+
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg,
7780
new InetSocketAddress(ip, port))
7881
}
7982
}

core/src/main/scala/org/apache/spark/network/SenderTest.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ package org.apache.spark.network
2020
import java.nio.ByteBuffer
2121
import org.apache.spark.{SecurityManager, SparkConf}
2222

23+
import scala.concurrent.Await
24+
import scala.concurrent.duration.Duration
25+
import scala.util.Try
26+
2327
private[spark] object SenderTest {
2428
def main(args: Array[String]) {
2529

@@ -51,7 +55,8 @@ private[spark] object SenderTest {
5155
val dataMessage = Message.createBufferMessage(buffer.duplicate)
5256
val startTime = System.currentTimeMillis
5357
/* println("Started timer at " + startTime) */
54-
val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage)
58+
val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage)
59+
val responseStr: String = Try(Await.result(promise, Duration.Inf))
5560
.map { response =>
5661
val buffer = response.asInstanceOf[BufferMessage].buffers(0)
5762
new String(buffer.array, "utf-8")

core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.util.concurrent.LinkedBlockingQueue
2222
import scala.collection.mutable.ArrayBuffer
2323
import scala.collection.mutable.HashSet
2424
import scala.collection.mutable.Queue
25+
import scala.util.{Failure, Success}
2526

2627
import io.netty.buffer.ByteBuf
2728

@@ -118,8 +119,8 @@ object BlockFetcherIterator {
118119
bytesInFlight += req.size
119120
val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
120121
val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
121-
future.onSuccess {
122-
case Some(message) => {
122+
future.onComplete {
123+
case Success(message) => {
123124
val bufferMessage = message.asInstanceOf[BufferMessage]
124125
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
125126
for (blockMessage <- blockMessageArray) {
@@ -135,8 +136,8 @@ object BlockFetcherIterator {
135136
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
136137
}
137138
}
138-
case None => {
139-
logError("Could not get block(s) from " + cmId)
139+
case Failure(exception) => {
140+
logError("Could not get block(s) from " + cmId, exception)
140141
for ((blockId, size) <- req.blocks) {
141142
results.put(new FetchResult(blockId, -1, null))
142143
}

0 commit comments

Comments
 (0)