17
17
18
18
package org .apache .spark .network
19
19
20
+ import java .io .IOException
20
21
import java .nio ._
21
22
import java .nio .channels ._
22
23
import java .nio .channels .spi ._
@@ -45,16 +46,26 @@ private[spark] class ConnectionManager(
45
46
name : String = " Connection manager" )
46
47
extends Logging {
47
48
49
+ /**
50
+ * Used by sendMessageReliably to track messages being sent.
51
+ * @param message the message that was sent
52
+ * @param connectionManagerId the connection manager this message was sent to
53
+ * @param completionHandler callback that's invoked when the send has completed or failed
54
+ */
48
55
class MessageStatus (
49
56
val message : Message ,
50
57
val connectionManagerId : ConnectionManagerId ,
51
58
completionHandler : MessageStatus => Unit ) {
52
59
60
+ /** This is non-None if message has been ack'd */
53
61
var ackMessage : Option [Message ] = None
54
- var attempted = false
55
- var acked = false
56
62
57
- def markDone () { completionHandler(this ) }
63
+ def markDone (ackMessage : Option [Message ]) {
64
+ this .synchronized {
65
+ this .ackMessage = ackMessage
66
+ completionHandler(this )
67
+ }
68
+ }
58
69
}
59
70
60
71
private val selector = SelectorProvider .provider.openSelector()
@@ -442,11 +453,7 @@ private[spark] class ConnectionManager(
442
453
messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
443
454
.foreach(status => {
444
455
logInfo(" Notifying " + status)
445
- status.synchronized {
446
- status.attempted = true
447
- status.acked = false
448
- status.markDone()
449
- }
456
+ status.markDone(None )
450
457
})
451
458
452
459
messageStatuses.retain((i, status) => {
@@ -475,11 +482,7 @@ private[spark] class ConnectionManager(
475
482
for (s <- messageStatuses.values
476
483
if s.connectionManagerId == sendingConnectionManagerId) {
477
484
logInfo(" Notifying " + s)
478
- s.synchronized {
479
- s.attempted = true
480
- s.acked = false
481
- s.markDone()
482
- }
485
+ s.markDone(None )
483
486
}
484
487
485
488
messageStatuses.retain((i, status) => {
@@ -547,13 +550,13 @@ private[spark] class ConnectionManager(
547
550
val securityMsgResp = SecurityMessage .fromResponse(replyToken,
548
551
securityMsg.getConnectionId.toString)
549
552
val message = securityMsgResp.toBufferMessage
550
- if (message == null ) throw new Exception (" Error creating security message" )
553
+ if (message == null ) throw new IOException (" Error creating security message" )
551
554
sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
552
555
} catch {
553
556
case e : Exception => {
554
557
logError(" Error handling sasl client authentication" , e)
555
558
waitingConn.close()
556
- throw new Exception (" Error evaluating sasl response: " + e)
559
+ throw new IOException (" Error evaluating sasl response: " , e)
557
560
}
558
561
}
559
562
}
@@ -661,34 +664,39 @@ private[spark] class ConnectionManager(
661
664
}
662
665
}
663
666
}
664
- sentMessageStatus.synchronized {
665
- sentMessageStatus.ackMessage = Some (message)
666
- sentMessageStatus.attempted = true
667
- sentMessageStatus.acked = true
668
- sentMessageStatus.markDone()
669
- }
667
+ sentMessageStatus.markDone(Some (message))
670
668
} else {
671
- val ackMessage = if (onReceiveCallback != null ) {
672
- logDebug(" Calling back" )
673
- onReceiveCallback(bufferMessage, connectionManagerId)
674
- } else {
675
- logDebug(" Not calling back as callback is null" )
676
- None
677
- }
669
+ var ackMessage : Option [Message ] = None
670
+ try {
671
+ ackMessage = if (onReceiveCallback != null ) {
672
+ logDebug(" Calling back" )
673
+ onReceiveCallback(bufferMessage, connectionManagerId)
674
+ } else {
675
+ logDebug(" Not calling back as callback is null" )
676
+ None
677
+ }
678
678
679
- if (ackMessage.isDefined) {
680
- if (! ackMessage.get.isInstanceOf [BufferMessage ]) {
681
- logDebug(" Response to " + bufferMessage + " is not a buffer message, it is of type "
682
- + ackMessage.get.getClass)
683
- } else if (! ackMessage.get.asInstanceOf [BufferMessage ].hasAckId) {
684
- logDebug(" Response to " + bufferMessage + " does not have ack id set" )
685
- ackMessage.get.asInstanceOf [BufferMessage ].ackId = bufferMessage.id
679
+ if (ackMessage.isDefined) {
680
+ if (! ackMessage.get.isInstanceOf [BufferMessage ]) {
681
+ logDebug(" Response to " + bufferMessage + " is not a buffer message, it is of type "
682
+ + ackMessage.get.getClass)
683
+ } else if (! ackMessage.get.asInstanceOf [BufferMessage ].hasAckId) {
684
+ logDebug(" Response to " + bufferMessage + " does not have ack id set" )
685
+ ackMessage.get.asInstanceOf [BufferMessage ].ackId = bufferMessage.id
686
+ }
687
+ }
688
+ } catch {
689
+ case e : Exception => {
690
+ logError(s " Exception was thrown while processing message " , e)
691
+ val m = Message .createBufferMessage(bufferMessage.id)
692
+ m.hasError = true
693
+ ackMessage = Some (m)
686
694
}
695
+ } finally {
696
+ sendMessage(connectionManagerId, ackMessage.getOrElse {
697
+ Message .createBufferMessage(bufferMessage.id)
698
+ })
687
699
}
688
-
689
- sendMessage(connectionManagerId, ackMessage.getOrElse {
690
- Message .createBufferMessage(bufferMessage.id)
691
- })
692
700
}
693
701
}
694
702
case _ => throw new Exception (" Unknown type message received" )
@@ -800,11 +808,7 @@ private[spark] class ConnectionManager(
800
808
case Some (msgStatus) => {
801
809
messageStatuses -= message.id
802
810
logInfo(" Notifying " + msgStatus.connectionManagerId)
803
- msgStatus.synchronized {
804
- msgStatus.attempted = true
805
- msgStatus.acked = false
806
- msgStatus.markDone()
807
- }
811
+ msgStatus.markDone(None )
808
812
}
809
813
case None => {
810
814
logError(" no messageStatus for failed message id: " + message.id)
@@ -823,23 +827,35 @@ private[spark] class ConnectionManager(
823
827
selector.wakeup()
824
828
}
825
829
830
+ /**
831
+ * Send a message; the returned Future completes once an acknowledgment is received or an error occurs.
832
+ * @param connectionManagerId the message's destination
833
+ * @param message the message being sent
834
+ * @return a Future that either returns the acknowledgment message or captures an exception.
835
+ */
826
836
def sendMessageReliably (connectionManagerId : ConnectionManagerId , message : Message )
827
- : Future [Option [Message ]] = {
828
- val promise = Promise [Option [Message ]]
829
- val status = new MessageStatus (
830
- message, connectionManagerId, s => promise.success(s.ackMessage))
837
+ : Future [Message ] = {
838
+ val promise = Promise [Message ]()
839
+ val status = new MessageStatus (message, connectionManagerId, s => {
840
+ s.ackMessage match {
841
+ case None => // Indicates a failure where we either never sent or never got ACK'd
842
+ promise.failure(new IOException (" sendMessageReliably failed without being ACK'd" ))
843
+ case Some (ackMessage) =>
844
+ if (ackMessage.hasError) {
845
+ promise.failure(
846
+ new IOException (" sendMessageReliably failed with ACK that signalled a remote error" ))
847
+ } else {
848
+ promise.success(ackMessage)
849
+ }
850
+ }
851
+ })
831
852
messageStatuses.synchronized {
832
853
messageStatuses += ((message.id, status))
833
854
}
834
855
sendMessage(connectionManagerId, message)
835
856
promise.future
836
857
}
837
858
838
- def sendMessageReliablySync (connectionManagerId : ConnectionManagerId ,
839
- message : Message ): Option [Message ] = {
840
- Await .result(sendMessageReliably(connectionManagerId, message), Duration .Inf )
841
- }
842
-
843
859
def onReceiveMessage (callback : (Message , ConnectionManagerId ) => Option [Message ]) {
844
860
onReceiveCallback = callback
845
861
}
@@ -862,6 +878,7 @@ private[spark] class ConnectionManager(
862
878
863
879
864
880
private [spark] object ConnectionManager {
881
+ import ExecutionContext .Implicits .global
865
882
866
883
def main (args : Array [String ]) {
867
884
val conf = new SparkConf
@@ -896,7 +913,7 @@ private[spark] object ConnectionManager {
896
913
897
914
(0 until count).map(i => {
898
915
val bufferMessage = Message .createBufferMessage(buffer.duplicate)
899
- manager.sendMessageReliablySync (manager.id, bufferMessage)
916
+ Await .result( manager.sendMessageReliably (manager.id, bufferMessage), Duration . Inf )
900
917
})
901
918
println(" --------------------------" )
902
919
println()
@@ -917,8 +934,10 @@ private[spark] object ConnectionManager {
917
934
val bufferMessage = Message .createBufferMessage(buffer.duplicate)
918
935
manager.sendMessageReliably(manager.id, bufferMessage)
919
936
}).foreach(f => {
920
- val g = Await .result(f, 1 second)
921
- if (! g.isDefined) println(" Failed" )
937
+ f.onFailure {
938
+ case e => println(" Failed due to " + e)
939
+ }
940
+ Await .ready(f, 1 second)
922
941
})
923
942
val finishTime = System .currentTimeMillis
924
943
@@ -952,8 +971,10 @@ private[spark] object ConnectionManager {
952
971
val bufferMessage = Message .createBufferMessage(buffers(count - 1 - i).duplicate)
953
972
manager.sendMessageReliably(manager.id, bufferMessage)
954
973
}).foreach(f => {
955
- val g = Await .result(f, 1 second)
956
- if (! g.isDefined) println(" Failed" )
974
+ f.onFailure {
975
+ case e => println(" Failed due to " + e)
976
+ }
977
+ Await .ready(f, 1 second)
957
978
})
958
979
val finishTime = System .currentTimeMillis
959
980
@@ -982,8 +1003,10 @@ private[spark] object ConnectionManager {
982
1003
val bufferMessage = Message .createBufferMessage(buffer.duplicate)
983
1004
manager.sendMessageReliably(manager.id, bufferMessage)
984
1005
}).foreach(f => {
985
- val g = Await .result(f, 1 second)
986
- if (! g.isDefined) println(" Failed" )
1006
+ f.onFailure {
1007
+ case e => println(" Failed due to " + e)
1008
+ }
1009
+ Await .ready(f, 1 second)
987
1010
})
988
1011
val finishTime = System .currentTimeMillis
989
1012
Thread .sleep(1000 )
0 commit comments