@@ -53,39 +53,54 @@ class JavaExecutor extends Executor {
  val httpClient = new HttpClient()
  // HACK: remove after the mesos team makes our work directory writable
  val baseDir = new File("/mnt")
+  val jarCache = new File(baseDir, "jarCache")

+  if (!jarCache.exists)
+    jarCache.mkdir
+
+  /* Keep a handle to all tasks that are running so we can kill them if needed later */
  case class RunningTask(proc: Process, stdout: StreamTailer, stderr: StreamTailer)
  val runningTasks = new scala.collection.mutable.HashMap[Int, RunningTask]

-  protected def loadClasspath(classSources: Seq[ClassSource]): String = classSources.pmap {
+  protected def loadClasspath(classSources: Seq[ClassSource]): String = classSources.map {
    case ServerSideJar(path) => path
-    // TODO: Cache these jars!
    case S3CachedJar(urlString) => {
-      val method = new GetMethod(urlString)
-      logger.info("Downloading %s", urlString)
-      httpClient.executeMethod(method)
-      val instream = method.getResponseBodyAsStream
-      val outfile = File.createTempFile("deploylibS3CachedJar", ".jar")
-      val outstream = new FileOutputStream(outfile)
-
-      var x = instream.read
-      while (x != -1) {
-        outstream.write(x)
-        x = instream.read
+      val jarUrl = new URL(urlString)
+
+      // Note: this makes the assumption that the name of the file is the MD5 hash of the file.
+      var jarMd5 = new File(jarUrl.getFile).getName
+      val cachedJar = new File(jarCache, jarMd5)
+
+      // TODO: Locks in case there are multiple executors on a machine
+      if ((!cachedJar.exists) || !(Util.md5(cachedJar) equals jarMd5)) {
+        val method = new GetMethod(urlString)
+        logger.info("Downloading %s", urlString)
+        httpClient.executeMethod(method)
+        val instream = method.getResponseBodyAsStream
+        val outstream = new FileOutputStream(cachedJar)
+
+        var x = instream.read
+        while (x != -1) {
+          outstream.write(x)
+          x = instream.read
+        }
+        instream.close
+        outstream.close
+        logger.info("Download of %s complete", urlString)
      }
-      instream.close
-      outstream.close
-      outfile.toString
-      logger.info("Download of %s complete", urlString)
+      cachedJar.getCanonicalPath()
    }
  }.mkString(":")

  override def launchTask(d: ExecutorDriver, taskDesc: TaskDescription): Unit = {
-    val launchDelay = Random.nextInt(30 * 1000)
+    val taskId = taskDesc.getTaskId // Note: use this because you can't hold on to taskDesc after this function exits.
+    d.sendStatusUpdate(new TaskStatus(taskId, TaskState.TASK_STARTING, new Array[Byte](0)))
+
+    val launchDelay = Random.nextInt(10 * 1000)
    logger.info("Delaying startup %dms to avoid overloading zookeeper", launchDelay)
    Thread.sleep(launchDelay)

-    logger.info("Starting task" + taskDesc.getTaskId())
+    logger.info("Starting task" + taskId)
    val tempDir = File.createTempFile("deploylib", "mesosJavaExecutorWorkingDir", baseDir)
    tempDir.delete()
    tempDir.mkdir()
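The new S3CachedJar branch leaves a TODO about multiple executors on the same machine racing to populate jarCache. A minimal sketch of one way to serialize the download with java.nio.channels.FileLock; the withCacheLock helper and the sibling ".lock" file are hypothetical, not part of this change:

  import java.io.{File, RandomAccessFile}

  // Hypothetical helper (not in the patch): take an exclusive lock on a sibling
  // ".lock" file so only one executor on the machine populates a cache entry at
  // a time; FileLock is advisory but is honored across JVM processes.
  def withCacheLock[T](cachedJar: File)(body: => T): T = {
    val lockFile = new File(cachedJar.getParentFile, cachedJar.getName + ".lock")
    val channel = new RandomAccessFile(lockFile, "rw").getChannel
    val lock = channel.lock() // blocks until no other process holds the lock
    try body
    finally {
      lock.release()
      channel.close()
    }
  }

Under that assumption, the download-and-verify block above could be wrapped as withCacheLock(cachedJar) { ... }, so one executor fetches the jar while the others block until it is in place.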
@@ -104,29 +119,28 @@ class JavaExecutor extends Executor {
      processDescription.mainclass) ++ processDescription.args

    logger.info("Execing: " + cmdLine.mkString(" "))
-    d.sendStatusUpdate(new TaskStatus(taskDesc.getTaskId, TaskState.TASK_STARTING, new Array[Byte](0)))
    val proc = Runtime.getRuntime().exec(cmdLine.filter(_.size != 0).toArray, Array[String](), tempDir)
    val stdout = new StreamTailer(proc.getInputStream())
    val stderr = new StreamTailer(proc.getErrorStream())
    def output = List(cmdLine, processDescription, "===stdout===", stdout.tail, "===stderr===", stderr.tail).mkString("\n").getBytes

-    val taskThread = new Thread("Task " + taskDesc.getTaskId + " Monitor") {
+    val taskThread = new Thread("Task " + taskId + " Monitor") {
      override def run() = {
        val result = proc.waitFor()
        val finalTaskState = result match {
          case 0 => TaskState.TASK_FINISHED
          case _ => TaskState.TASK_FAILED
        }
-        d.sendStatusUpdate(new TaskStatus(taskDesc.getTaskId, finalTaskState, output))
-        logger.info("Cleaning up working directory %s for %d", tempDir, taskDesc.getTaskId())
+        d.sendStatusUpdate(new TaskStatus(taskId, finalTaskState, output))
+        logger.info("Cleaning up working directory %s for %d", tempDir, taskId)
        deleteRecursive(tempDir)
-        logger.info("Task %d", taskDesc.getTaskId())
+        logger.info("Done cleaning up after Task %d", taskId)
      }
    }
-    taskThread.run()
-    runningTasks += ((taskDesc.getTaskId(), RunningTask(proc, stdout, stderr)))
-    d.sendStatusUpdate(new TaskStatus(taskDesc.getTaskId, TaskState.TASK_RUNNING, output))
-    logger.info("Task %d started", taskDesc.getTaskId())
+    taskThread.start()
+    runningTasks += ((taskId, RunningTask(proc, stdout, stderr)))
+    d.sendStatusUpdate(new TaskStatus(taskId, TaskState.TASK_RUNNING, output))
+    logger.info("Task %d started", taskId)
  }

  override def killTask(d: ExecutorDriver, taskId: Int): Unit = {
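One notable fix in launchTask is the switch from taskThread.run() to taskThread.start(): run() merely invokes the body synchronously on the calling thread, so the executor would sit inside launchTask until the child process exited, while start() runs the monitor on its own thread. A standalone illustration; the RunVsStart object exists only for demonstration:

  // Thread.run() executes the body on the caller's thread; Thread.start()
  // executes it on a freshly spawned thread.
  object RunVsStart extends App {
    val t = new Thread("task-monitor-example") {
      override def run() = println("running on: " + Thread.currentThread.getName)
    }
    t.run()   // prints "running on: main" -- caller is blocked until it returns
    t.start() // prints "running on: task-monitor-example" -- runs concurrently
  }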