
Commit e72afdb

Some refactoring to make cluster scheduler pluggable.
1 parent 5d1a887 commit e72afdb

34 files changed, +718 -509 lines

core/src/main/scala/spark/PairRDDFunctions.scala

Lines changed: 9 additions & 2 deletions
@@ -307,9 +307,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     val jobtrackerID = formatter.format(new Date())
     val stageId = self.id
     def writeShard(context: spark.TaskContext, iter: Iterator[(K,V)]): Int = {
+      // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
+      // around by taking a mod. We expect that no task will be attempted 2 billion times.
+      val attemptNumber = (context.attemptId % Int.MaxValue).toInt
       /* "reduce task" <split #> <attempt # = spark task #> */
       val attemptId = new TaskAttemptID(jobtrackerID,
-        stageId, false, context.splitId, context.attemptId)
+        stageId, false, context.splitId, attemptNumber)
       val hadoopContext = new TaskAttemptContext(wrappedConf.value, attemptId)
       val format = outputFormatClass.newInstance
       val committer = format.getOutputCommitter(hadoopContext)
@@ -371,7 +374,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     writer.preSetup()
 
     def writeToFile(context: TaskContext, iter: Iterator[(K,V)]) {
-      writer.setup(context.stageId, context.splitId, context.attemptId)
+      // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
+      // around by taking a mod. We expect that no task will be attempted 2 billion times.
+      val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+
+      writer.setup(context.stageId, context.splitId, attemptNumber)
       writer.open()
 
       var count = 0
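
Since TaskContext.attemptId is now a Long (see the TaskContext change below) while Hadoop's TaskAttemptID and the writer.setup call only accept an Int, the value is folded back into the 32-bit range with a mod. A minimal sketch of that arithmetic, using made-up attempt IDs rather than anything from this commit:

    // Made-up attempt IDs, only to illustrate (attemptId % Int.MaxValue).toInt:
    def toIntAttempt(attemptId: Long): Int = (attemptId % Int.MaxValue).toInt

    toIntAttempt(12L)                        // 12: unchanged for realistic attempt counts
    toIntAttempt(Int.MaxValue.toLong + 5L)   // 5: wraps around instead of overflowing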

core/src/main/scala/spark/SparkContext.scala

Lines changed: 7 additions & 6 deletions
@@ -41,8 +41,8 @@ import spark.scheduler.ShuffleMapTask
 import spark.scheduler.DAGScheduler
 import spark.scheduler.TaskScheduler
 import spark.scheduler.local.LocalScheduler
+import spark.scheduler.cluster.ClusterScheduler
 import spark.scheduler.mesos.MesosScheduler
-import spark.scheduler.mesos.CoarseMesosScheduler
 import spark.storage.BlockManagerMaster
 
 class SparkContext(
@@ -89,11 +89,17 @@ class SparkContext(
         new LocalScheduler(threads.toInt, maxFailures.toInt)
       case _ =>
         MesosNativeLibrary.load()
+        val sched = new ClusterScheduler(this)
+        val schedContext = new MesosScheduler(sched, this, master, frameworkName)
+        sched.initialize(schedContext)
+        sched
+        /*
         if (System.getProperty("spark.mesos.coarse", "false") == "true") {
           new CoarseMesosScheduler(this, master, frameworkName)
         } else {
           new MesosScheduler(this, master, frameworkName)
         }
+        */
     }
   }
   taskScheduler.start()
@@ -272,11 +278,6 @@ class SparkContext(
     logInfo("Successfully stopped SparkContext")
   }
 
-  // Wait for the scheduler to be registered with the cluster manager
-  def waitForRegister() {
-    taskScheduler.waitForRegister()
-  }
-
   // Get Spark's home location from either a value set through the constructor,
   // or the spark.home Java property, or the SPARK_HOME environment variable
   // (in that order of preference). If neither of these is set, return None.
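
This is the core of the "pluggable" refactoring: SparkContext now builds the generic ClusterScheduler and hands it a backend object via initialize (here the reworked MesosScheduler), instead of constructing a Mesos-specific scheduler directly. The interface between ClusterScheduler and its backend is not part of this excerpt, so the following is only a sketch of the wiring pattern under assumed names (SchedulerBackend, PluggableScheduler), not the actual API introduced by this commit:

    // Sketch only: trait and method names below are assumptions.
    trait SchedulerBackend {
      def start(): Unit
    }

    class PluggableScheduler {
      private var backend: SchedulerBackend = _
      // Two-step wiring, mirroring sched.initialize(schedContext) above.
      def initialize(b: SchedulerBackend) { backend = b }
      def start() { backend.start() }
    }

    val sched = new PluggableScheduler
    sched.initialize(new SchedulerBackend {
      def start() { println("backend registered with the cluster manager") }
    })
    sched.start()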
core/src/main/scala/spark/TaskContext.scala

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 package spark
 
-class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable
+class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable
core/src/main/scala/spark/TaskState.scala

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+package spark
+
+import org.apache.mesos.Protos.{TaskState => MesosTaskState}
+
+object TaskState
+  extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") {
+
+  val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value
+
+  type TaskState = Value
+
+  def isFinished(state: TaskState) = Seq(FINISHED, FAILED, LOST).contains(state)
+
+  def toMesos(state: TaskState): MesosTaskState = state match {
+    case LAUNCHING => MesosTaskState.TASK_STARTING
+    case RUNNING => MesosTaskState.TASK_RUNNING
+    case FINISHED => MesosTaskState.TASK_FINISHED
+    case FAILED => MesosTaskState.TASK_FAILED
+    case KILLED => MesosTaskState.TASK_KILLED
+    case LOST => MesosTaskState.TASK_LOST
+  }
+
+  def fromMesos(mesosState: MesosTaskState): TaskState = mesosState match {
+    case MesosTaskState.TASK_STAGING => LAUNCHING
+    case MesosTaskState.TASK_STARTING => LAUNCHING
+    case MesosTaskState.TASK_RUNNING => RUNNING
+    case MesosTaskState.TASK_FINISHED => FINISHED
+    case MesosTaskState.TASK_FAILED => FAILED
+    case MesosTaskState.TASK_KILLED => KILLED
+    case MesosTaskState.TASK_LOST => LOST
+  }
+}
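
TaskState is a new, scheduler-agnostic enumeration of task states, with converters to and from the Mesos protobuf states so that only the Mesos backend has to deal with MesosTaskState. Since the whole file is shown above, the snippet below just exercises it; note the mapping is not one-to-one (fromMesos sends both TASK_STAGING and TASK_STARTING to LAUNCHING):

    import spark.TaskState
    import org.apache.mesos.Protos.{TaskState => MesosTaskState}

    val s = TaskState.fromMesos(MesosTaskState.TASK_FINISHED)   // FINISHED
    TaskState.isFinished(s)                                     // true
    TaskState.toMesos(TaskState.LAUNCHING)                      // TASK_STARTING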

core/src/main/scala/spark/Utils.scala

Lines changed: 10 additions & 0 deletions
@@ -13,6 +13,7 @@ import scala.io.Source
  * Various utility methods used by Spark.
  */
 object Utils {
+  /** Serialize an object using Java serialization */
   def serialize[T](o: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
     val oos = new ObjectOutputStream(bos)
@@ -21,12 +22,14 @@ object Utils {
     return bos.toByteArray
   }
 
+  /** Deserialize an object using Java serialization */
   def deserialize[T](bytes: Array[Byte]): T = {
     val bis = new ByteArrayInputStream(bytes)
     val ois = new ObjectInputStream(bis)
     return ois.readObject.asInstanceOf[T]
   }
 
+  /** Deserialize an object using Java serialization and the given ClassLoader */
   def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
     val bis = new ByteArrayInputStream(bytes)
     val ois = new ObjectInputStream(bis) {
@@ -106,6 +109,13 @@ object Utils {
     }
   }
 
+  /** Copy a file on the local file system */
+  def copyFile(source: File, dest: File) {
+    val in = new FileInputStream(source)
+    val out = new FileOutputStream(dest)
+    copyStream(in, out, true)
+  }
+
   /**
    * Shuffle the elements of a collection into a random order, returning the
    * result in a new collection. Unlike scala.util.Random.shuffle, this method
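
The new copyFile helper just opens the two streams and delegates to the existing Utils.copyStream, which is not part of this diff; the trailing true argument presumably tells it to close the streams when done. Usage of the new helper and of the now-documented serialize/deserialize pair, with hypothetical paths:

    import java.io.File
    import spark.Utils

    // Hypothetical paths, only to show the call shape:
    Utils.copyFile(new File("/tmp/source.dat"), new File("/tmp/dest.dat"))

    // serialize/deserialize round-trip any java.io.Serializable value:
    val bytes = Utils.serialize(Map("a" -> 1, "b" -> 2))
    val back  = Utils.deserialize[Map[String, Int]](bytes)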

core/src/main/scala/spark/deploy/DeployMessage.scala

Lines changed: 4 additions & 3 deletions
@@ -1,5 +1,7 @@
 package spark.deploy
 
+import spark.deploy.ExecutorState.ExecutorState
+
 sealed trait DeployMessage extends Serializable
 
 // Worker to Master
@@ -10,8 +12,7 @@ case class RegisterWorker(id: String, host: String, port: Int, cores: Int, memor
 case class ExecutorStateChanged(
     jobId: String,
     execId: Int,
-    state:
-      ExecutorState.Value,
+    state: ExecutorState,
     message: Option[String])
   extends DeployMessage
 
@@ -38,7 +39,7 @@ case class RegisterJob(jobDescription: JobDescription) extends DeployMessage
 
 case class RegisteredJob(jobId: String) extends DeployMessage
 case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
-case class ExecutorUpdated(id: Int, state: ExecutorState.Value, message: Option[String])
+case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String])
 case class JobKilled(message: String)
 
 // Internal message in Client
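
With the type alias imported, the deploy messages carry a state: ExecutorState field rather than the wordier ExecutorState.Value. A hypothetical receiver (not part of this commit) would match on it roughly like this:

    import spark.deploy._
    import spark.deploy.ExecutorState

    // Hypothetical handler, only to show how the typed field reads at a match site:
    def handle(msg: DeployMessage) = msg match {
      case ExecutorStateChanged(jobId, execId, state, message) =>
        if (ExecutorState.isFinished(state))
          println("Executor " + execId + " of job " + jobId + " ended in state " + state)
      case _ => ()
    }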

core/src/main/scala/spark/deploy/ExecutorState.scala

Lines changed: 3 additions & 1 deletion
@@ -5,5 +5,7 @@ object ExecutorState
 
   val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value
 
-  def isFinished(state: Value): Boolean = (state == KILLED || state == FAILED || state == LOST)
+  type ExecutorState = Value
+
+  def isFinished(state: ExecutorState): Boolean = Seq(KILLED, FAILED, LOST).contains(state)
 }

core/src/main/scala/spark/deploy/JobDescription.scala

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@ class JobDescription(
     val name: String,
     val cores: Int,
     val memoryPerSlave: Int,
-    val fileUrls: Seq[String],
     val command: Command)
   extends Serializable {
 

core/src/main/scala/spark/deploy/client/TestClient.scala

Lines changed: 2 additions & 2 deletions
@@ -24,8 +24,8 @@ object TestClient {
   def main(args: Array[String]) {
     val url = args(0)
     val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress(), 0)
-    val desc = new JobDescription("TestClient", 1, 512, Seq(),
-      Command("spark.deploy.client.TestExecutor", Seq(), Map()))
+    val desc = new JobDescription(
+      "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()))
     val listener = new TestListener
     val client = new Client(actorSystem, url, desc, listener)
     client.start()

core/src/main/scala/spark/deploy/client/TestExecutor.scala

Lines changed: 3 additions & 0 deletions
@@ -3,5 +3,8 @@ package spark.deploy.client
 object TestExecutor {
   def main(args: Array[String]) {
     println("Hello world!")
+    while (true) {
+      Thread.sleep(1000)
+    }
   }
 }
