Skip to content

Commit 252d53c

Browse files
author
Andrew Or
committed
Clean up server error handling behavior further
This introduces a new response code for unknown protocol version in case a future client wants to retry using older versions. This commit also uses more specific error messages depending on how the request URL is malformed.
1 parent c643f64 commit 252d53c

File tree

1 file changed

+47
-30
lines changed

1 file changed

+47
-30
lines changed

core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.deploy.rest
1919

2020
import java.io.{DataOutputStream, File}
2121
import java.net.InetSocketAddress
22-
import javax.servlet.http.{HttpServlet, HttpServletResponse, HttpServletRequest}
22+
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
2323

2424
import scala.io.Source
2525

@@ -48,6 +48,15 @@ private[spark] class StandaloneRestServer(master: Master, host: String, requeste
4848
import StandaloneRestServer._
4949

5050
private var _server: Option[Server] = None
51+
private val basePrefix = s"/$PROTOCOL_VERSION/submissions"
52+
53+
// A mapping from servlets to the URL prefixes they are responsible for
54+
private val servletToPrefix = Map[StandaloneRestServlet, String](
55+
new SubmitRequestServlet(master) -> s"$basePrefix/create/*",
56+
new KillRequestServlet(master) -> s"$basePrefix/kill/*",
57+
new StatusRequestServlet(master) -> s"$basePrefix/status/*",
58+
new ErrorServlet -> "/"
59+
)
5160

5261
/** Start the server and return the bound port. */
5362
def start(): Int = {
@@ -58,28 +67,19 @@ private[spark] class StandaloneRestServer(master: Master, host: String, requeste
5867
}
5968

6069
/**
61-
* Set up the mapping from contexts to the appropriate servlets:
62-
* (1) submit requests should be directed to /create
63-
* (2) kill requests should be directed to /kill
64-
* (3) status requests should be directed to /status
70+
* Map the servlets to their corresponding contexts and attach them to a server.
6571
* Return a 2-tuple of the started server and the bound port.
6672
*/
6773
private def doStart(startPort: Int): (Server, Int) = {
6874
val server = new Server(new InetSocketAddress(host, requestedPort))
6975
val threadPool = new QueuedThreadPool
7076
threadPool.setDaemon(true)
7177
server.setThreadPool(threadPool)
72-
val pathPrefix = s"/$PROTOCOL_VERSION/submissions"
7378
val mainHandler = new ServletContextHandler
7479
mainHandler.setContextPath("/")
75-
mainHandler.addServlet(
76-
new ServletHolder(new SubmitRequestServlet(master)), s"$pathPrefix/create")
77-
mainHandler.addServlet(
78-
new ServletHolder(new KillRequestServlet(master)), s"$pathPrefix/kill/*")
79-
mainHandler.addServlet(
80-
new ServletHolder(new StatusRequestServlet(master)), s"$pathPrefix/status/*")
81-
mainHandler.addServlet(
82-
new ServletHolder(new ErrorServlet), "/")
80+
servletToPrefix.foreach { case (servlet, prefix) =>
81+
mainHandler.addServlet(new ServletHolder(servlet), prefix)
82+
}
8383
server.setHandler(mainHandler)
8484
server.start()
8585
val boundPort = server.getConnectors()(0).getLocalPort
@@ -93,6 +93,7 @@ private[spark] class StandaloneRestServer(master: Master, host: String, requeste
9393

9494
private object StandaloneRestServer {
9595
val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION
96+
val SC_UNKNOWN_PROTOCOL_VERSION = 468
9697
}
9798

9899
/**
@@ -257,7 +258,6 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
257258
responseServlet: HttpServletResponse): Unit = {
258259
val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString
259260
val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson)
260-
.asInstanceOf[SubmitRestProtocolRequest]
261261
val responseMessage = handleSubmit(requestMessage, responseServlet)
262262
responseServlet.setContentType("application/json")
263263
responseServlet.setCharacterEncoding("utf-8")
@@ -268,8 +268,13 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
268268
out.close()
269269
}
270270

271+
/**
272+
* Handle a submit request by first validating the request message, then submitting the
273+
* application using the parameters specified in the message. If the message is not of
274+
* the expected type, return error to the client.
275+
*/
271276
private def handleSubmit(
272-
requestMessage: SubmitRestProtocolRequest,
277+
requestMessage: SubmitRestProtocolMessage,
273278
responseServlet: HttpServletResponse): SubmitRestProtocolResponse = {
274279
// The response should have already been validated on the client.
275280
// In case this is not true, validate it ourselves to avoid potential NPEs.
@@ -293,8 +298,7 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
293298
submitResponse
294299
case unexpected =>
295300
responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST)
296-
handleError(
297-
s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.")
301+
handleError(s"Received message of unexpected type ${unexpected.messageType}.")
298302
}
299303
}
300304

@@ -366,23 +370,36 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
366370
*/
367371
private[spark] class ErrorServlet extends StandaloneRestServlet {
368372
private val expectedVersion = StandaloneRestServer.PROTOCOL_VERSION
373+
374+
/** Service a faulty request by returning an appropriate error message to the client. */
369375
protected override def service(
370376
request: HttpServletRequest,
371377
response: HttpServletResponse): Unit = {
378+
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
372379
val path = request.getPathInfo
373-
val parts = path.stripPrefix("/").split("/")
374-
if (parts.nonEmpty) {
375-
val version = parts.head
376-
if (version != expectedVersion) {
377-
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
378-
val error = handleError(s"Incompatible protocol version $version")
379-
sendResponse(error, response)
380-
return
380+
val parts = path.stripPrefix("/").split("/").toSeq
381+
var msg =
382+
parts match {
383+
case Nil =>
384+
// http://host:port/
385+
"Missing protocol version."
386+
case `expectedVersion` :: Nil =>
387+
// http://host:port/correct-version
388+
"Missing the /submissions prefix."
389+
case `expectedVersion` :: "submissions" :: Nil =>
390+
// http://host:port/correct-version/submissions
391+
"Missing an action: please specify one of /create, /kill, or /status."
392+
case unknownVersion :: _ =>
393+
// http://host:port/unknown-version/*
394+
// Use a special response code in case the client wants to retry with a different version
395+
response.setStatus(StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION)
396+
s"Unknown protocol version '$unknownVersion'."
397+
case _ =>
398+
// never reached
399+
s"Malformed path $path."
381400
}
382-
}
383-
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
384-
val error = handleError(
385-
s"Unexpected path $path: Please submit requests through /$expectedVersion/submissions/")
401+
msg += s" Please submit requests through http://[host]:[port]/$expectedVersion/submissions/..."
402+
val error = handleError(msg)
386403
sendResponse(error, response)
387404
}
388405
}

0 commit comments

Comments
 (0)