
Commit 45b2317

A standard RPC interface and an Akka implementation
1 parent 1768bd5 commit 45b2317

12 files changed, +1420 -79 lines
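
The diffs below replace direct uses of Akka (ActorSystem, Actor, ActorRef) with a transport-agnostic RPC layer in org.apache.spark.rpc, backed here by an Akka implementation (AkkaRpcEnv). The rpc package itself belongs to this commit but is not included in this excerpt, so the following sketch of its core types is inferred from the call sites in the diffs; it is an approximation, not the actual source.

```scala
// Approximate shape of the new RPC interface, inferred from the call sites in
// the diffs below. The real definitions live in org.apache.spark.rpc and
// org.apache.spark.rpc.akka, which this excerpt does not include.
import org.apache.spark.{SecurityManager, SparkConf}

// Host/port identity of a remote RpcEnv, replacing akka.actor.Address.
case class RpcAddress(host: String, port: Int)

// A handle to an endpoint, replacing ActorRef; send() is fire-and-forget,
// like Akka's `!`.
trait RpcEndpointRef {
  def send(message: Any): Unit
}

// A message-handling entity, replacing Actor.
trait RpcEndpoint {
  val rpcEnv: RpcEnv
  def onStart(): Unit = {}                 // replaces Actor.preStart()
  def receive(sender: RpcEndpointRef): PartialFunction[Any, Unit]
}

// The communication environment, replacing ActorSystem.
abstract class RpcEnv {
  def address: RpcAddress                  // the host/port actually bound
  def setupEndpoint(name: String, endpoint: => RpcEndpoint): RpcEndpointRef
  def setupDriverEndpointRef(name: String): RpcEndpointRef   // look up by name on the driver
  def setupEndpointRefByUrl(url: String): RpcEndpointRef     // replaces actorSelection(url)
  def shutdown(): Unit
}

object RpcEnv {
  // Factory mirroring AkkaUtils.createActorSystem; the Akka-backed
  // implementation would construct an AkkaRpcEnv here.
  def create(name: String, host: String, port: Int, conf: SparkConf,
      securityManager: SecurityManager): RpcEnv = ???
}
```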

core/src/main/scala/org/apache/spark/SparkEnv.scala
Lines changed: 27 additions & 13 deletions

@@ -34,8 +34,10 @@ import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.network.BlockTransferService
 import org.apache.spark.network.netty.NettyBlockTransferService
 import org.apache.spark.network.nio.NioBlockTransferService
+import org.apache.spark.rpc.akka.AkkaRpcEnv
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv}
 import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus}
-import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorActor
+import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
 import org.apache.spark.storage._

@@ -54,7 +56,7 @@ import org.apache.spark.util.{AkkaUtils, Utils}
 @DeveloperApi
 class SparkEnv (
     val executorId: String,
-    val actorSystem: ActorSystem,
+    val rpcEnv: RpcEnv,
     val serializer: Serializer,
     val closureSerializer: Serializer,
     val cacheManager: CacheManager,

@@ -71,6 +73,9 @@ class SparkEnv (
     val outputCommitCoordinator: OutputCommitCoordinator,
     val conf: SparkConf) extends Logging {

+  // TODO Remove actorSystem
+  val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
+
   private[spark] var isStopped = false
   private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()

@@ -91,7 +96,8 @@
     blockManager.master.stop()
     metricsSystem.stop()
     outputCommitCoordinator.stop()
-    actorSystem.shutdown()
+    rpcEnv.shutdown()
+
     // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
     // down, but let's call it anyway in case it gets fixed in a later release
     // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it.

@@ -236,16 +242,15 @@ object SparkEnv extends Logging {
     val securityManager = new SecurityManager(conf)

     // Create the ActorSystem for Akka and get the port it binds to.
-    val (actorSystem, boundPort) = {
-      val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
-      AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)
-    }
+    val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
+    val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager)
+    val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem

     // Figure out which port Akka actually bound to in case the original port is 0 or occupied.
     if (isDriver) {
-      conf.set("spark.driver.port", boundPort.toString)
+      conf.set("spark.driver.port", rpcEnv.address.port.toString)
     } else {
-      conf.set("spark.executor.port", boundPort.toString)
+      conf.set("spark.executor.port", rpcEnv.address.port.toString)
     }

     // Create an instance of the class with the given name, possibly initializing it with our conf

@@ -290,6 +295,15 @@ object SparkEnv extends Logging {
       }
     }

+    def registerOrLookupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = {
+      if (isDriver) {
+        logInfo("Registering " + name)
+        rpcEnv.setupEndpoint(name, endpointCreator)
+      } else {
+        rpcEnv.setupDriverEndpointRef(name)
+      }
+    }
+
     val mapOutputTracker = if (isDriver) {
       new MapOutputTrackerMaster(conf)
     } else {

@@ -377,13 +391,13 @@ object SparkEnv extends Logging {
     val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
       new OutputCommitCoordinator(conf)
     }
-    val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator",
-      new OutputCommitCoordinatorActor(outputCommitCoordinator))
-    outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor)
+    val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator",
+      new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
+    outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)

     val envInstance = new SparkEnv(
       executorId,
-      actorSystem,
+      rpcEnv,
       serializer,
       closureSerializer,
       cacheManager,

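The new registerOrLookupEndpoint helper generalizes the old registerOrLookup (which worked on actors) to the endpoint types: the driver creates and registers the endpoint, while executors resolve a remote reference to it by name, so callers are identical on both sides. Note also that the bound port is now recovered from rpcEnv.address rather than returned by the factory. A hedged usage sketch follows; EchoEndpoint is hypothetical and not part of this commit.

```scala
import org.apache.spark.Logging
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}

// Hypothetical endpoint, used only to illustrate the pattern; not in this commit.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint with Logging {
  override def receive(sender: RpcEndpointRef) = {
    case msg => logInfo(s"echo: $msg")   // runs where the endpoint lives: the driver
  }
}

// Inside SparkEnv.create, where registerOrLookupEndpoint and rpcEnv are in
// scope. On the driver (isDriver == true) this registers the endpoint; on an
// executor the same call resolves a remote reference by name.
val echoRef = registerOrLookupEndpoint("Echo", new EchoEndpoint(rpcEnv))
echoRef.send("ping")
```
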
core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
Lines changed: 4 additions & 3 deletions

@@ -21,6 +21,7 @@ import java.io.File

 import akka.actor._

+import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, Utils}

@@ -32,9 +33,9 @@ object DriverWrapper {
     args.toList match {
       case workerUrl :: userJar :: mainClass :: extraArgs =>
         val conf = new SparkConf()
-        val (actorSystem, _) = AkkaUtils.createActorSystem("Driver",
+        val rpcEnv = RpcEnv.create("Driver",
           Utils.localHostName(), 0, conf, new SecurityManager(conf))
-        actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher")
+        rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl))

         val currentLoader = Thread.currentThread.getContextClassLoader
         val userJarUrl = new File(userJar).toURI().toURL()

@@ -51,7 +52,7 @@ object DriverWrapper {
         val mainMethod = clazz.getMethod("main", classOf[Array[String]])
         mainMethod.invoke(null, extraArgs.toArray[String])

-        actorSystem.shutdown()
+        rpcEnv.shutdown()

       case _ =>
         System.err.println("Usage: DriverWrapper <workerUrl> <userJar> <driverMainClass> [options]")

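Beyond swapping types, this change drops the reflective Props construction: endpoints are instantiated directly. A condensed sketch of the lifecycle this file now follows, mapping the old Akka calls to their RpcEnv equivalents (DriverLifecycleSketch is an illustrative name, and the sketch assumes the interface outlined near the top of this page):

```scala
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.worker.WorkerWatcher
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.util.Utils

object DriverLifecycleSketch {
  def runDriver(workerUrl: String): Unit = {
    val conf = new SparkConf()

    // was: AkkaUtils.createActorSystem("Driver", host, 0, conf, securityManager)
    // Port 0 lets the transport pick a free port; the bound port is later
    // available as rpcEnv.address.port.
    val rpcEnv = RpcEnv.create("Driver", Utils.localHostName(), 0, conf,
      new SecurityManager(conf))

    // was: actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher")
    // The endpoint is constructed directly, so its constructor arguments are
    // type-checked rather than passed through reflective Props.
    rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl))

    // ... load the user jar and invoke its main method, as DriverWrapper does ...

    // was: actorSystem.shutdown()
    rpcEnv.shutdown()
  }
}
```
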
core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
Lines changed: 31 additions & 27 deletions

@@ -17,26 +17,23 @@

 package org.apache.spark.deploy.worker

-import akka.actor.{Actor, Address, AddressFromURIString}
-import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, DisassociatedEvent, RemotingLifecycleEvent}
-
 import org.apache.spark.Logging
 import org.apache.spark.deploy.DeployMessages.SendHeartbeat
-import org.apache.spark.util.ActorLogReceive
+import org.apache.spark.rpc._

 /**
  * Actor which connects to a worker process and terminates the JVM if the connection is severed.
  * Provides fate sharing between a worker and its associated child processes.
  */
-private[spark] class WorkerWatcher(workerUrl: String)
-  extends Actor with ActorLogReceive with Logging {
-
-  override def preStart() {
-    context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String)
+  extends NetworkRpcEndpoint with Logging {

+  override def onStart() {
     logInfo(s"Connecting to worker $workerUrl")
-    val worker = context.actorSelection(workerUrl)
-    worker ! SendHeartbeat // need to send a message here to initiate connection
+    if (!isTesting) {
+      val worker = rpcEnv.setupEndpointRefByUrl(workerUrl)
+      worker.send(SendHeartbeat) // need to send a message here to initiate connection
+    }
   }

   // Used to avoid shutting down JVM during tests

@@ -45,30 +42,37 @@ private[spark] class WorkerWatcher(workerUrl: String)
   private var isTesting = false

   // Lets us filter events only from the worker's actor system
-  private val expectedHostPort = AddressFromURIString(workerUrl).hostPort
-  private def isWorker(address: Address) = address.hostPort == expectedHostPort
+  private val expectedHostPort = new java.net.URI(workerUrl)
+  private def isWorker(address: RpcAddress) = {
+    expectedHostPort.getHost == address.host && expectedHostPort.getPort == address.port
+  }

   def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1)

-  override def receiveWithLogging = {
-    case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) =>
-      logInfo(s"Successfully connected to $workerUrl")
+  override def receive(sender: RpcEndpointRef) = {
+    case e => logWarning(s"Received unexpected actor system event: $e")
+  }

-    case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _)
-        if isWorker(remoteAddress) =>
-      // These logs may not be seen if the worker (and associated pipe) has died
-      logError(s"Could not initialize connection to worker $workerUrl. Exiting.")
-      logError(s"Error was: $cause")
-      exitNonZero()
+  override def onConnected(remoteAddress: RpcAddress): Unit = {
+    if (isWorker(remoteAddress)) {
+      logInfo(s"Successfully connected to $workerUrl")
+    }
+  }

-    case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) =>
+  override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+    if (isWorker(remoteAddress)) {
       // This log message will never be seen
       logError(s"Lost connection to worker actor $workerUrl. Exiting.")
       exitNonZero()
+    }
+  }

-    case e: AssociationEvent =>
-      // pass through association events relating to other remote actor systems
-
-    case e => logWarning(s"Received unexpected actor system event: $e")
+  override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+    if (isWorker(remoteAddress)) {
+      // These logs may not be seen if the worker (and associated pipe) has died
+      logError(s"Could not initialize connection to worker $workerUrl. Exiting.")
+      logError(s"Error was: $cause")
+      exitNonZero()
+    }
   }
 }

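WorkerWatcher no longer subscribes to Akka's RemotingLifecycleEvent stream; connection tracking moves into typed callbacks. Judging from the overrides above, NetworkRpcEndpoint plausibly extends the basic endpoint roughly as follows (an inference from the call sites, not the actual trait):

```scala
import org.apache.spark.rpc.{RpcAddress, RpcEndpoint}

// Inferred shape, not the real source: connection-lifecycle hooks that replace
// pattern matching on Akka remoting events inside receive. The diff above
// suggests the mapping:
//   AssociatedEvent       -> onConnected
//   DisassociatedEvent    -> onDisconnected
//   AssociationErrorEvent -> onNetworkError
trait NetworkRpcEndpoint extends RpcEndpoint {
  def onConnected(remoteAddress: RpcAddress): Unit = {}
  def onDisconnected(remoteAddress: RpcAddress): Unit = {}
  def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {}
}
```

One consequence visible above: the explicit pass-through case for association events from unrelated remote systems disappears, since each hook filters with isWorker and events from other peers simply do nothing.
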
core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
Lines changed: 1 addition & 1 deletion

@@ -169,7 +169,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
         driverUrl, executorId, sparkHostPort, cores, userClassPath, env),
       name = "Executor")
     workerUrl.foreach { url =>
-      env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
+      env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
    }
    env.actorSystem.awaitTermination()
  }
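
One Akka dependency survives here: the backend still blocks on env.actorSystem.awaitTermination(), through the actorSystem escape hatch kept in SparkEnv (marked TODO there). A fully abstracted version would presumably block on the RpcEnv instead, along the lines of the hypothetical call below.

```scala
// Hypothetical; this commit does not add such a method. Shown only to mark
// where the remaining Akka coupling would go away.
env.rpcEnv.awaitTermination()
```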
