Skip to content

Commit

Permalink
Version the protocol and include it in REST URL
Browse files Browse the repository at this point in the history
Now the REST URLs look like this:
http://host:port/v1/submissions/create
http://host:port/v1/submissions/kill/driver_123
http://host:port/v1/submissions/status/driver_123
  • Loading branch information
Andrew Or committed Feb 4, 2015
1 parent 721819f commit f98660b
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,18 @@ import org.apache.spark.deploy.SparkSubmitArguments
* currently used for cluster mode only.
*
* The specific request sent to the server depends on the action as follows:
* (1) submit - POST to http://.../submissions/create
* (2) kill - POST http://.../submissions/kill/[submissionId]
* (3) status - GET http://.../submissions/status/[submissionId]
* (1) submit - POST to /submissions/create
* (2) kill - POST /submissions/kill/[submissionId]
* (3) status - GET /submissions/status/[submissionId]
*
* In the case of (1), parameters are posted in the HTTP body in the form of JSON fields.
* Otherwise, the URL fully specifies the intended action of the client.
*
* Additionally, the base URL includes the version of the protocol. For instance:
* http://1.2.3.4:6066/v1/submissions/create. Since the protocol is expected to be stable
* across Spark versions, existing fields cannot be added or removed. In the rare event that
* backward compatibility is broken, Spark must introduce a new protocol version (e.g. v2).
* The client and the server must communicate on the same version of the protocol.
*/
private[spark] class StandaloneRestClient extends Logging {
import StandaloneRestClient._
Expand Down Expand Up @@ -147,20 +153,25 @@ private[spark] class StandaloneRestClient extends Logging {

/** Return the REST URL for creating a new submission. */
private def getSubmitUrl(master: String): URL = {
val baseUrl = master.stripPrefix("spark://")
new URL(s"http://$baseUrl/submissions/create")
val baseUrl = getBaseUrl(master)
new URL(s"$baseUrl/submissions/create")
}

/** Return the REST URL for killing an existing submission. */
private def getKillUrl(master: String, submissionId: String): URL = {
val baseUrl = master.stripPrefix("spark://")
new URL(s"http://$baseUrl/submissions/kill/$submissionId")
val baseUrl = getBaseUrl(master)
new URL(s"$baseUrl/submissions/kill/$submissionId")
}

/** Return the REST URL for requesting the status of an existing submission. */
private def getStatusUrl(master: String, submissionId: String): URL = {
val baseUrl = master.stripPrefix("spark://")
new URL(s"http://$baseUrl/submissions/status/$submissionId")
val baseUrl = getBaseUrl(master)
new URL(s"$baseUrl/submissions/status/$submissionId")
}

/** Return the base URL for communicating with the server, including the protocol version. */
private def getBaseUrl(master: String): String = {
"http://" + master.stripPrefix("spark://").stripSuffix("/") + "/" + PROTOCOL_VERSION
}

/** Throw an exception if this is not standalone mode. */
Expand Down Expand Up @@ -261,4 +272,5 @@ private[spark] class StandaloneRestClient extends Logging {
private object StandaloneRestClient {
val REPORT_DRIVER_STATUS_INTERVAL = 1000
val REPORT_DRIVER_STATUS_MAX_TRIES = 10
val PROTOCOL_VERSION = "v1"
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ private[spark] class StandaloneRestServer(
requestedPort: Int)
extends Logging {

import StandaloneRestServer._

private var _server: Option[Server] = None

/** Start the server and return the bound port. */
Expand All @@ -67,11 +69,17 @@ private[spark] class StandaloneRestServer(
val threadPool = new QueuedThreadPool
threadPool.setDaemon(true)
server.setThreadPool(threadPool)
val pathPrefix = s"/$PROTOCOL_VERSION/submissions"
val mainHandler = new ServletContextHandler
mainHandler.setContextPath("/submissions")
mainHandler.addServlet(new ServletHolder(new KillRequestServlet(master)), "/kill/*")
mainHandler.addServlet(new ServletHolder(new StatusRequestServlet(master)), "/status/*")
mainHandler.addServlet(new ServletHolder(new SubmitRequestServlet(master)), "/create")
mainHandler.setContextPath("/")
mainHandler.addServlet(
new ServletHolder(new SubmitRequestServlet(master)), s"$pathPrefix/create")
mainHandler.addServlet(
new ServletHolder(new KillRequestServlet(master)), s"$pathPrefix/kill/*")
mainHandler.addServlet(
new ServletHolder(new StatusRequestServlet(master)), s"$pathPrefix/status/*")
mainHandler.addServlet(
new ServletHolder(new ErrorServlet), "/")
server.setHandler(mainHandler)
server.start()
val boundPort = server.getConnectors()(0).getLocalPort
Expand All @@ -83,6 +91,10 @@ private[spark] class StandaloneRestServer(
}
}

private object StandaloneRestServer {
val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION
}

/**
* An abstract servlet for handling requests passed to the [[StandaloneRestServer]].
*/
Expand Down Expand Up @@ -346,3 +358,23 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command)
}
}

/**
* A default servlet that handles error cases that are not captured by other servlets.
*/
private[spark] class ErrorServlet extends HttpServlet {
private val expectedVersion = StandaloneRestServer.PROTOCOL_VERSION
override def service(request: HttpServletRequest, response: HttpServletResponse): Unit = {
val path = request.getPathInfo
val parts = path.stripPrefix("/").split("/")
if (parts.nonEmpty) {
val version = parts.head
if (version != expectedVersion) {
response.sendError(800, s"Incompatible protocol version $version")
return
}
}
response.sendError(801,
s"Unexpected path $path: Please submit requests through /$expectedVersion/submissions/")
}
}

0 comments on commit f98660b

Please sign in to comment.