From 484bd2172b847433c989d7c450fbbc99dddb1f56 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 19 Jan 2015 17:02:25 -0800 Subject: [PATCH] Specify an ordering for fields in SubmitDriverRequestMessage Previously APP_ARGs, SPARK_PROPERTYs and ENVIRONMENT_VARIABLEs will appear in the JSON at random places. Now they are grouped together at the end of the JSON blob. --- .../rest/SubmitDriverRequestMessage.scala | 18 ++++++++++ .../rest/SubmitRestProtocolMessage.scala | 34 +++++++++++-------- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala index e55cb69ed112d..30c203e003f11 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala @@ -83,6 +83,8 @@ private[spark] class SubmitDriverRequestMessage extends SubmitRestProtocolMessag SubmitDriverRequestField.ACTION, SubmitDriverRequestField.requiredFields) { + import SubmitDriverRequestField._ + // Ensure continuous range of app arg indices starting from 0 override def validate(): this.type = { import SubmitDriverRequestField._ @@ -93,6 +95,22 @@ private[spark] class SubmitDriverRequestMessage extends SubmitRestProtocolMessag } super.validate() } + + // List the fields in the following order: + // ACTION < SPARK_VERSION < * < APP_ARG < SPARK_PROPERTY < ENVIRONMENT_VARIABLE < MESSAGE + protected override def sortedFields: Seq[(SubmitRestProtocolField, String)] = { + fields.toSeq.sortBy { case (k, _) => + k match { + case ACTION => 0 + case SPARK_VERSION => 1 + case APP_ARG(index) => 10 + index + case SPARK_PROPERTY(propKey) => 100 + case ENVIRONMENT_VARIABLE(envKey) => 1000 + case MESSAGE => Int.MaxValue + case _ => 2 + } + } + } } private[spark] object SubmitDriverRequestMessage extends SubmitRestProtocolMessageCompanion { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 77d38c40ae80d..db5a42d7b17da 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -24,7 +24,7 @@ import org.json4s.jackson.JsonMethods._ import org.json4s.JsonAST._ import org.apache.spark.{Logging, SparkException} -import org.apache.spark.util.Utils +import org.apache.spark.util.{JsonProtocol, Utils} /** * A field used in a SubmitRestProtocolMessage. @@ -32,8 +32,9 @@ import org.apache.spark.util.Utils */ private[spark] abstract class SubmitRestProtocolField private[spark] object SubmitRestProtocolField { - /** Return whether the provided field name refers to the ACTION field. */ def isActionField(field: String): Boolean = field == "ACTION" + def isSparkVersionField(field: String): Boolean = field == "SPARK_VERSION" + def isMessageField(field: String): Boolean = field == "MESSAGE" } /** @@ -125,23 +126,26 @@ private[spark] abstract class SubmitRestProtocolMessage( /** Return the JSON representation of this message. */ def toJson: String = { - val stringFields = fields + val jsonFields = sortedFields .filter { case (_, v) => v != null } - .map { case (k, v) => (k.toString, v) } - val jsonFields = fieldsToJson(stringFields) - pretty(render(jsonFields)) + .map { case (k, v) => JField(k.toString, JString(v)) } + .toList + pretty(render(JObject(jsonFields))) } /** - * Return the JSON representation of the message fields, putting ACTION first. - * This assumes that applying `org.apache.spark.util.JsonProtocol.mapFromJson` - * to the result yields the original input. + * Return a list of (field, value) pairs with the following ordering: + * ACTION < SPARK_VERSION < * < MESSAGE */ - private def fieldsToJson(fields: Map[String, String]): JValue = { - val jsonFields = fields.toList - .sortBy { case (k, _) => if (isActionField(k)) 0 else 1 } - .map { case (k, v) => JField(k, JString(v)) } - JObject(jsonFields) + protected def sortedFields: Seq[(SubmitRestProtocolField, String)] = { + fields.toSeq.sortBy { case (k, _) => + k.toString match { + case x if isActionField(x) => 0 + case x if isSparkVersionField(x) => 1 + case x if isMessageField(x) => Int.MaxValue + case _ => 2 + } + } } } @@ -155,7 +159,7 @@ private[spark] object SubmitRestProtocolMessage { * If such a field does not exist in the JSON, throw an exception. */ def fromJson(json: String): SubmitRestProtocolMessage = { - val fields = org.apache.spark.util.JsonProtocol.mapFromJson(parse(json)) + val fields = JsonProtocol.mapFromJson(parse(json)) val action = fields .flatMap { case (k, v) => if (isActionField(k)) Some(v) else None } .headOption