@@ -19,7 +19,7 @@ package org.apache.spark.deploy.rest
19
19
20
20
import java .io .{DataOutputStream , File }
21
21
import java .net .InetSocketAddress
22
- import javax .servlet .http .{HttpServlet , HttpServletResponse , HttpServletRequest }
22
+ import javax .servlet .http .{HttpServlet , HttpServletRequest , HttpServletResponse }
23
23
24
24
import scala .io .Source
25
25
@@ -48,6 +48,15 @@ private[spark] class StandaloneRestServer(master: Master, host: String, requeste
48
48
import StandaloneRestServer ._
49
49
50
50
private var _server : Option [Server ] = None
51
+ private val basePrefix = s " / $PROTOCOL_VERSION/submissions "
52
+
53
+ // A mapping from servlets to the URL prefixes they are responsible for
54
+ private val servletToPrefix = Map [StandaloneRestServlet , String ](
55
+ new SubmitRequestServlet (master) -> s " $basePrefix/create/* " ,
56
+ new KillRequestServlet (master) -> s " $basePrefix/kill/* " ,
57
+ new StatusRequestServlet (master) -> s " $basePrefix/status/* " ,
58
+ new ErrorServlet -> " /"
59
+ )
51
60
52
61
/** Start the server and return the bound port. */
53
62
def start (): Int = {
@@ -58,28 +67,19 @@ private[spark] class StandaloneRestServer(master: Master, host: String, requeste
58
67
}
59
68
60
69
/**
61
- * Set up the mapping from contexts to the appropriate servlets:
62
- * (1) submit requests should be directed to /create
63
- * (2) kill requests should be directed to /kill
64
- * (3) status requests should be directed to /status
70
+ * Map the servlets to their corresponding contexts and attach them to a server.
65
71
* Return a 2-tuple of the started server and the bound port.
66
72
*/
67
73
private def doStart (startPort : Int ): (Server , Int ) = {
68
74
val server = new Server (new InetSocketAddress (host, requestedPort))
69
75
val threadPool = new QueuedThreadPool
70
76
threadPool.setDaemon(true )
71
77
server.setThreadPool(threadPool)
72
- val pathPrefix = s " / $PROTOCOL_VERSION/submissions "
73
78
val mainHandler = new ServletContextHandler
74
79
mainHandler.setContextPath(" /" )
75
- mainHandler.addServlet(
76
- new ServletHolder (new SubmitRequestServlet (master)), s " $pathPrefix/create " )
77
- mainHandler.addServlet(
78
- new ServletHolder (new KillRequestServlet (master)), s " $pathPrefix/kill/* " )
79
- mainHandler.addServlet(
80
- new ServletHolder (new StatusRequestServlet (master)), s " $pathPrefix/status/* " )
81
- mainHandler.addServlet(
82
- new ServletHolder (new ErrorServlet ), " /" )
80
+ servletToPrefix.foreach { case (servlet, prefix) =>
81
+ mainHandler.addServlet(new ServletHolder (servlet), prefix)
82
+ }
83
83
server.setHandler(mainHandler)
84
84
server.start()
85
85
val boundPort = server.getConnectors()(0 ).getLocalPort
@@ -93,6 +93,7 @@ private[spark] class StandaloneRestServer(master: Master, host: String, requeste
93
93
94
94
private object StandaloneRestServer {
95
95
val PROTOCOL_VERSION = StandaloneRestClient .PROTOCOL_VERSION
96
+ val SC_UNKNOWN_PROTOCOL_VERSION = 468
96
97
}
97
98
98
99
/**
@@ -257,7 +258,6 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
257
258
responseServlet : HttpServletResponse ): Unit = {
258
259
val requestMessageJson = Source .fromInputStream(requestServlet.getInputStream).mkString
259
260
val requestMessage = SubmitRestProtocolMessage .fromJson(requestMessageJson)
260
- .asInstanceOf [SubmitRestProtocolRequest ]
261
261
val responseMessage = handleSubmit(requestMessage, responseServlet)
262
262
responseServlet.setContentType(" application/json" )
263
263
responseServlet.setCharacterEncoding(" utf-8" )
@@ -268,8 +268,13 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
268
268
out.close()
269
269
}
270
270
271
+ /**
272
+ * Handle a submit request by first validating the request message, then submitting the
273
+ * application using the parameters specified in the message. If the message is not of
274
+ * the expected type, return error to the client.
275
+ */
271
276
private def handleSubmit (
272
- requestMessage : SubmitRestProtocolRequest ,
277
+ requestMessage : SubmitRestProtocolMessage ,
273
278
responseServlet : HttpServletResponse ): SubmitRestProtocolResponse = {
274
279
// The response should have already been validated on the client.
275
280
// In case this is not true, validate it ourselves to avoid potential NPEs.
@@ -293,8 +298,7 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
293
298
submitResponse
294
299
case unexpected =>
295
300
responseServlet.setStatus(HttpServletResponse .SC_BAD_REQUEST )
296
- handleError(
297
- s " Received message of unexpected type ${Utils .getFormattedClassName(unexpected)}. " )
301
+ handleError(s " Received message of unexpected type ${unexpected.messageType}. " )
298
302
}
299
303
}
300
304
@@ -366,23 +370,36 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
366
370
*/
367
371
private [spark] class ErrorServlet extends StandaloneRestServlet {
368
372
private val expectedVersion = StandaloneRestServer .PROTOCOL_VERSION
373
+
374
+ /** Service a faulty request by returning an appropriate error message to the client. */
369
375
protected override def service (
370
376
request : HttpServletRequest ,
371
377
response : HttpServletResponse ): Unit = {
378
+ response.setStatus(HttpServletResponse .SC_BAD_REQUEST )
372
379
val path = request.getPathInfo
373
- val parts = path.stripPrefix(" /" ).split(" /" )
374
- if (parts.nonEmpty) {
375
- val version = parts.head
376
- if (version != expectedVersion) {
377
- response.setStatus(HttpServletResponse .SC_BAD_REQUEST )
378
- val error = handleError(s " Incompatible protocol version $version" )
379
- sendResponse(error, response)
380
- return
380
+ val parts = path.stripPrefix(" /" ).split(" /" ).toSeq
381
+ var msg =
382
+ parts match {
383
+ case Nil =>
384
+ // http://host:port/
385
+ " Missing protocol version."
386
+ case `expectedVersion` :: Nil =>
387
+ // http://host:port/correct-version
388
+ " Missing the /submissions prefix."
389
+ case `expectedVersion` :: " submissions" :: Nil =>
390
+ // http://host:port/correct-version/submissions
391
+ " Missing an action: please specify one of /create, /kill, or /status."
392
+ case unknownVersion :: _ =>
393
+ // http://host:port/unknown-version/*
394
+ // Use a special response code in case the client wants to retry with a different version
395
+ response.setStatus(StandaloneRestServer .SC_UNKNOWN_PROTOCOL_VERSION )
396
+ s " Unknown protocol version ' $unknownVersion'. "
397
+ case _ =>
398
+ // never reached
399
+ s " Malformed path $path. "
381
400
}
382
- }
383
- response.setStatus(HttpServletResponse .SC_BAD_REQUEST )
384
- val error = handleError(
385
- s " Unexpected path $path: Please submit requests through / $expectedVersion/submissions/ " )
401
+ msg += s " Please submit requests through http://[host]:[port]/ $expectedVersion/submissions/... "
402
+ val error = handleError(msg)
386
403
sendResponse(error, response)
387
404
}
388
405
}
0 commit comments