17
17
18
18
package org .apache .spark .network
19
19
20
+ import java .io .IOException
20
21
import java .nio ._
21
22
import java .nio .channels ._
22
23
import java .nio .channels .spi ._
@@ -41,16 +42,26 @@ import org.apache.spark.util.{SystemClock, Utils}
41
42
private [spark] class ConnectionManager (port : Int , conf : SparkConf ,
42
43
securityManager : SecurityManager ) extends Logging {
43
44
45
+ /**
46
+ * Used by sendMessageReliably to track messages being sent.
47
+ * @param message the message that was sent
48
+ * @param connectionManagerId the connection manager that sent this message
49
+ * @param completionHandler callback that's invoked when the send has completed or failed
50
+ */
44
51
class MessageStatus (
45
52
val message : Message ,
46
53
val connectionManagerId : ConnectionManagerId ,
47
54
completionHandler : MessageStatus => Unit ) {
48
55
56
+ /** This is non-None if message has been ack'd */
49
57
var ackMessage : Option [Message ] = None
50
- var attempted = false
51
- var acked = false
52
58
53
- def markDone () { completionHandler(this ) }
59
+ def markDone (ackMessage : Option [Message ]) {
60
+ this .synchronized {
61
+ this .ackMessage = ackMessage
62
+ completionHandler(this )
63
+ }
64
+ }
54
65
}
55
66
56
67
private val selector = SelectorProvider .provider.openSelector()
@@ -434,11 +445,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
434
445
messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
435
446
.foreach(status => {
436
447
logInfo(" Notifying " + status)
437
- status.synchronized {
438
- status.attempted = true
439
- status.acked = false
440
- status.markDone()
441
- }
448
+ status.markDone(None )
442
449
})
443
450
444
451
messageStatuses.retain((i, status) => {
@@ -467,11 +474,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
467
474
for (s <- messageStatuses.values
468
475
if s.connectionManagerId == sendingConnectionManagerId) {
469
476
logInfo(" Notifying " + s)
470
- s.synchronized {
471
- s.attempted = true
472
- s.acked = false
473
- s.markDone()
474
- }
477
+ s.markDone(None )
475
478
}
476
479
477
480
messageStatuses.retain((i, status) => {
@@ -539,13 +542,13 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
539
542
val securityMsgResp = SecurityMessage .fromResponse(replyToken,
540
543
securityMsg.getConnectionId.toString)
541
544
val message = securityMsgResp.toBufferMessage
542
- if (message == null ) throw new Exception (" Error creating security message" )
545
+ if (message == null ) throw new IOException (" Error creating security message" )
543
546
sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
544
547
} catch {
545
548
case e : Exception => {
546
549
logError(" Error handling sasl client authentication" , e)
547
550
waitingConn.close()
548
- throw new Exception (" Error evaluating sasl response: " + e)
551
+ throw new IOException (" Error evaluating sasl response: " + e)
549
552
}
550
553
}
551
554
}
@@ -653,12 +656,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
653
656
}
654
657
}
655
658
}
656
- sentMessageStatus.synchronized {
657
- sentMessageStatus.ackMessage = Some (message)
658
- sentMessageStatus.attempted = true
659
- sentMessageStatus.acked = true
660
- sentMessageStatus.markDone()
661
- }
659
+ sentMessageStatus.markDone(Some (message))
662
660
} else {
663
661
var ackMessage : Option [Message ] = None
664
662
try {
@@ -681,7 +679,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
681
679
}
682
680
} catch {
683
681
case e : Exception => {
684
- logError(s " Exception was thrown during processing message " , e)
682
+ logError(s " Exception was thrown while processing message " , e)
685
683
val m = Message .createBufferMessage(bufferMessage.id)
686
684
m.hasError = true
687
685
ackMessage = Some (m)
@@ -802,11 +800,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
802
800
case Some (msgStatus) => {
803
801
messageStatuses -= message.id
804
802
logInfo(" Notifying " + msgStatus.connectionManagerId)
805
- msgStatus.synchronized {
806
- msgStatus.attempted = true
807
- msgStatus.acked = false
808
- msgStatus.markDone()
809
- }
803
+ msgStatus.markDone(None )
810
804
}
811
805
case None => {
812
806
logError(" no messageStatus for failed message id: " + message.id)
@@ -825,11 +819,28 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
825
819
selector.wakeup()
826
820
}
827
821
822
+ /**
823
+ * Send a message and block until an acknowldgment is received or an error occurs.
824
+ * @param connectionManagerId the message's destination
825
+ * @param message the message being sent
826
+ * @return a Future that either returns the acknowledgment message or captures an exception.
827
+ */
828
828
def sendMessageReliably (connectionManagerId : ConnectionManagerId , message : Message )
829
- : Future [Option [Message ]] = {
830
- val promise = Promise [Option [Message ]]
831
- val status = new MessageStatus (
832
- message, connectionManagerId, s => promise.success(s.ackMessage))
829
+ : Future [Message ] = {
830
+ val promise = Promise [Message ]()
831
+ val status = new MessageStatus (message, connectionManagerId, s => {
832
+ s.ackMessage match {
833
+ case None => // Indicates a failure where we either never sent or never got ACK'd
834
+ promise.failure(new IOException (" sendMessageReliably failed without being ACK'd" ))
835
+ case Some (ackMessage) =>
836
+ if (ackMessage.hasError) {
837
+ promise.failure(
838
+ new IOException (" sendMessageReliably failed with ACK that signalled a remote error" ))
839
+ } else {
840
+ promise.success(ackMessage)
841
+ }
842
+ }
843
+ })
833
844
messageStatuses.synchronized {
834
845
messageStatuses += ((message.id, status))
835
846
}
@@ -838,7 +849,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
838
849
}
839
850
840
851
def sendMessageReliablySync (connectionManagerId : ConnectionManagerId ,
841
- message : Message ): Option [ Message ] = {
852
+ message : Message ): Message = {
842
853
Await .result(sendMessageReliably(connectionManagerId, message), Duration .Inf )
843
854
}
844
855
@@ -864,6 +875,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
864
875
865
876
866
877
private [spark] object ConnectionManager {
878
+ import ExecutionContext .Implicits .global
867
879
868
880
def main (args : Array [String ]) {
869
881
val conf = new SparkConf
@@ -919,8 +931,10 @@ private[spark] object ConnectionManager {
919
931
val bufferMessage = Message .createBufferMessage(buffer.duplicate)
920
932
manager.sendMessageReliably(manager.id, bufferMessage)
921
933
}).foreach(f => {
922
- val g = Await .result(f, 1 second)
923
- if (! g.isDefined) println(" Failed" )
934
+ f.onFailure {
935
+ case e => println(" Failed due to " + e)
936
+ }
937
+ Await .ready(f, 1 second)
924
938
})
925
939
val finishTime = System .currentTimeMillis
926
940
@@ -954,8 +968,10 @@ private[spark] object ConnectionManager {
954
968
val bufferMessage = Message .createBufferMessage(buffers(count - 1 - i).duplicate)
955
969
manager.sendMessageReliably(manager.id, bufferMessage)
956
970
}).foreach(f => {
957
- val g = Await .result(f, 1 second)
958
- if (! g.isDefined) println(" Failed" )
971
+ f.onFailure {
972
+ case e => println(" Failed due to " + e)
973
+ }
974
+ Await .ready(f, 1 second)
959
975
})
960
976
val finishTime = System .currentTimeMillis
961
977
@@ -984,8 +1000,10 @@ private[spark] object ConnectionManager {
984
1000
val bufferMessage = Message .createBufferMessage(buffer.duplicate)
985
1001
manager.sendMessageReliably(manager.id, bufferMessage)
986
1002
}).foreach(f => {
987
- val g = Await .result(f, 1 second)
988
- if (! g.isDefined) println(" Failed" )
1003
+ f.onFailure {
1004
+ case e => println(" Failed due to " + e)
1005
+ }
1006
+ Await .ready(f, 1 second)
989
1007
})
990
1008
val finishTime = System .currentTimeMillis
991
1009
Thread .sleep(1000 )
0 commit comments