Skip to content

[SPARK-12267][Core]Store the remote RpcEnv address to send the correct disconnection message #10261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ private[spark] class ApplicationInfo(
nextExecutorId = 0
removedExecutors = new ArrayBuffer[ExecutorDesc]
executorLimit = Integer.MAX_VALUE
appUIUrlAtHistoryServer = None
}

private def newExecutorId(useID: Option[Int] = None): Int = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ private[deploy] object Worker extends Logging {
val conf = new SparkConf
val args = new WorkerArguments(argStrings, conf)
val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores,
args.memory, args.masters, args.workDir)
args.memory, args.masters, args.workDir, conf = conf)
rpcEnv.awaitTermination()
}

Expand Down
21 changes: 21 additions & 0 deletions core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,9 @@ private[netty] class NettyRpcHandler(
// A variable to track whether we should dispatch the RemoteProcessConnected message.
private val clients = new ConcurrentHashMap[TransportClient, JBoolean]()

// A variable to track the remote RpcEnv addresses of all clients
private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]()

override def receive(
client: TransportClient,
message: ByteBuffer,
Expand Down Expand Up @@ -580,6 +583,12 @@ private[netty] class NettyRpcHandler(
// Create a new message with the socket address of the client as the sender.
RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
} else {
// The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for
// the listening address
val remoteEnvAddress = requestMessage.senderAddress
if (remoteAddresses.putIfAbsent(clientAddr, remoteEnvAddress) == null) {
dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to be concerned about multiple network messages with different addresses, since the code already handles bad addresses.

}
requestMessage
}
}
Expand All @@ -591,6 +600,12 @@ private[netty] class NettyRpcHandler(
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr))
// If the remote RpcEnv listens to some address, we should also fire a
// RemoteProcessConnectionError for the remote RpcEnv listening address
val remoteEnvAddress = remoteAddresses.get(clientAddr)
if (remoteEnvAddress != null) {
dispatcher.postToAll(RemoteProcessConnectionError(cause, remoteEnvAddress))
}
} else {
// If the channel is closed before connecting, its remoteAddress will be null.
// See java.net.Socket.getRemoteSocketAddress
Expand All @@ -606,6 +621,12 @@ private[netty] class NettyRpcHandler(
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
nettyEnv.removeOutbox(clientAddr)
dispatcher.postToAll(RemoteProcessDisconnected(clientAddr))
val remoteEnvAddress = remoteAddresses.remove(clientAddr)
// If the remote RpcEnv listens to some address, we should also fire a
// RemoteProcessDisconnected for the remote RpcEnv listening address
if (remoteEnvAddress != null) {
dispatcher.postToAll(RemoteProcessDisconnected(remoteEnvAddress))
}
} else {
// If the channel is closed before connecting, its remoteAddress will be null. In this case,
// we can ignore it since we don't fire "Associated".
Expand Down
42 changes: 42 additions & 0 deletions core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,48 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
}
}

test("network events between non-client-mode RpcEnvs") {
  // Records every (event-name, payload) pair the endpoint observes, so the
  // assertions below can check which lifecycle callbacks fired and with what address.
  val observed =
    new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)]
  env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint {
    override val rpcEnv = env

    override def receive: PartialFunction[Any, Unit] = {
      case "hello" => // the greeting itself is not an event we track
      case other => observed += "receive" -> other
    }

    override def onConnected(remoteAddress: RpcAddress): Unit = {
      observed += "onConnected" -> remoteAddress
    }

    override def onDisconnected(remoteAddress: RpcAddress): Unit = {
      observed += "onDisconnected" -> remoteAddress
    }

    override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
      observed += "onNetworkError" -> remoteAddress
    }

  })

  // A second RpcEnv in non-client mode: it listens on a port of its own,
  // so its listening address should appear in the connection events.
  val serverEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = false)
  // Use serverEnv to look up the RpcEndpointRef registered above.
  val endpointRef = serverEnv.setupEndpointRef(
    "local", env.address, "network-events-non-client")
  val listeningAddress = serverEnv.address
  endpointRef.send("hello")
  eventually(timeout(5 seconds), interval(5 millis)) {
    assert(observed.contains(("onConnected", listeningAddress)))
  }

  serverEnv.shutdown()
  serverEnv.awaitTermination()
  eventually(timeout(5 seconds), interval(5 millis)) {
    assert(observed.contains(("onConnected", listeningAddress)))
    assert(observed.contains(("onDisconnected", listeningAddress)))
  }
}

test("sendWithReply: unserializable error") {
env.setupEndpoint("sendWithReply-unserializable-error", new RpcEndpoint {
override val rpcEnv = env
Expand Down