Skip to content

Commit

Permalink
Clean up uses of exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Or committed Jan 30, 2015
1 parent 914fdff commit 9581df7
Show file tree
Hide file tree
Showing 12 changed files with 80 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class DriverStatusRequest extends SubmitRestProtocolRequest {
def getDriverId: String = driverId.toString
def setDriverId(s: String): this.type = setField(driverId, s)

override def validate(): Unit = {
super.validate()
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(driverId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class DriverStatusResponse extends SubmitRestProtocolResponse {
def setWorkerId(s: String): this.type = setField(workerId, s)
def setWorkerHostPort(s: String): this.type = setField(workerHostPort, s)

override def validate(): Unit = {
super.validate()
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(driverId)
assertFieldIsSet(success)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ package org.apache.spark.deploy.rest
* An error response message used in the REST application submission protocol.
*/
class ErrorResponse extends SubmitRestProtocolResponse {
override def validate(): Unit = {
super.validate()
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(message)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class KillDriverRequest extends SubmitRestProtocolRequest {
def getDriverId: String = driverId.toString
def setDriverId(s: String): this.type = setField(driverId, s)

override def validate(): Unit = {
super.validate()
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(driverId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class KillDriverResponse extends SubmitRestProtocolResponse {
def setDriverId(s: String): this.type = setField(driverId, s)
def setSuccess(s: String): this.type = setBooleanField(success, s)

override def validate(): Unit = {
super.validate()
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(driverId)
assertFieldIsSet(success)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest {
if (map.nonEmpty) { compact(render(JsonProtocol.mapToJson(map))) } else null
}

override def validate(): Unit = {
super.validate()
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(appName)
assertFieldIsSet(appResource)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class SubmitDriverResponse extends SubmitRestProtocolResponse {
def setSuccess(s: String): this.type = setBooleanField(success, s)
def setDriverId(s: String): this.type = setField(driverId, s)

override def validate(): Unit = {
super.validate()
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(success)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,8 @@ package org.apache.spark.deploy.rest
*/
class SubmitRestProtocolField[T](val name: String) {
protected var value: Option[T] = None

/** Return the value or throw an [[IllegalArgumentException]] if the value is not set. */
def getValue: T = {
value.getOrElse {
throw new IllegalAccessException(s"Value not set in field '$name'!")
}
}

def isSet: Boolean = value.isDefined
def getValueOption: Option[T] = value
def getValue: Option[T] = value
def setValue(v: T): Unit = { value = Some(v) }
def clearValue(): Unit = { value = None }
override def toString: String = value.map(_.toString).orNull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,42 +67,60 @@ abstract class SubmitRestProtocolMessage {
pretty(parse(mapper.writeValueAsString(this)))
}

/** Assert the validity of the message. */
def validate(): Unit = {
assert(action != null, s"The action field is missing in $messageType!")
/**
* Assert the validity of the message.
* If the validation fails, throw a [[SubmitRestValidationException]].
*/
final def validate(): Unit = {
try {
doValidate()
} catch {
case e: Exception =>
throw new SubmitRestValidationException(
s"Validation of message $messageType failed!", e)
}
}

/** Assert the validity of the message */
protected def doValidate(): Unit = {
assert(action != null, s"The action field is missing.")
assertFieldIsSet(sparkVersion)
}

/** Assert that the specified field is set in this message. */
protected def assertFieldIsSet(field: SubmitRestProtocolField[_]): Unit = {
assert(field.isSet, s"Field '${field.name}' is missing in $messageType!")
assert(field.isSet, s"Field '${field.name}' is missing.")
}

/**
* Assert a condition when validating this message.
* If the assertion fails, throw a [[SubmitRestValidationException]].
*/
protected def assert(condition: Boolean, failMessage: String): Unit = {
if (!condition) { throw new SubmitRestValidationException(failMessage) }
}

/** Set the field to the given value, or clear the field if the value is null. */
protected def setField(field: SubmitRestProtocolField[String], value: String): this.type = {
if (value == null) { field.clearValue() } else { field.setValue(value) }
protected def setField(f: SubmitRestProtocolField[String], v: String): this.type = {
if (v == null) { f.clearValue() } else { f.setValue(v) }
this
}

/**
* Set the field to the given boolean value, or clear the field if the value is null.
* If the provided value does not represent a boolean, throw an exception.
*/
protected def setBooleanField(
field: SubmitRestProtocolField[Boolean],
value: String): this.type = {
if (value == null) { field.clearValue() } else { field.setValue(value.toBoolean) }
protected def setBooleanField(f: SubmitRestProtocolField[Boolean], v: String): this.type = {
if (v == null) { f.clearValue() } else { f.setValue(v.toBoolean) }
this
}

/**
* Set the field to the given numeric value, or clear the field if the value is null.
* If the provided value does not represent a numeric, throw an exception.
*/
protected def setNumericField(
field: SubmitRestProtocolField[Int],
value: String): this.type = {
if (value == null) { field.clearValue() } else { field.setValue(value.toInt) }
protected def setNumericField(f: SubmitRestProtocolField[Int], v: String): this.type = {
if (v == null) { f.clearValue() } else { f.setValue(v.toInt) }
this
}

Expand All @@ -111,12 +129,9 @@ abstract class SubmitRestProtocolMessage {
* If the provided value does not represent a memory value, throw an exception.
* Valid examples of memory values include "512m", "24g", and "128000".
*/
protected def setMemoryField(
field: SubmitRestProtocolField[String],
value: String): this.type = {
Utils.memoryStringToMb(value)
setField(field, value)
this
protected def setMemoryField(f: SubmitRestProtocolField[String], v: String): this.type = {
Utils.memoryStringToMb(v)
setField(f, v)
}
}

Expand All @@ -142,6 +157,14 @@ abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage {
override def setSparkVersion(s: String) = setServerSparkVersion(s)
}

/**
* An exception thrown if the validation of a [[SubmitRestProtocolMessage]] fails.
*/
class SubmitRestValidationException(
message: String,
cause: Exception = null)
extends Exception(message, cause)

object SubmitRestProtocolMessage {
private val mapper = new ObjectMapper
private val packagePrefix = this.getClass.getPackage.getName
Expand All @@ -162,7 +185,7 @@ object SubmitRestProtocolMessage {
/**
* Construct a [[SubmitRestProtocolMessage]] from its JSON representation.
*
* This method first parses the action from the JSON and uses it to infers the message type.
* This method first parses the action from the JSON and uses it to infer the message type.
* Note that the action must represent one of the [[SubmitRestProtocolMessage]]s defined in
* this package. Otherwise, a [[ClassNotFoundException]] will be thrown.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi

/**
* Construct the appropriate response message based on the type of the request message.
* If an [[IllegalArgumentException]] is thrown, construct an error message instead.
* If an exception is thrown, construct an error message instead.
*/
private def constructResponseMessage(
request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = {
Expand All @@ -122,15 +122,14 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi
s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.")
}
} catch {
case e: IllegalArgumentException => handleError(formatException(e))
case e: Exception => handleError(formatException(e))
}
// Validate the response message to ensure that it is correctly constructed. If it is not,
// propagate the exception back to the client and signal that it is a server error.
try {
response.validate()
} catch {
case e: IllegalArgumentException =>
handleError("Internal server error: " + formatException(e))
case e: Exception => handleError("Internal server error: " + formatException(e))
}
response
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B
assert(killSuccess === "true")
assert(statusSuccess === "true")
assert(driverState === DriverState.KILLED.toString)
// we should not see the expected results because we killed the driver
intercept[TestFailedException] { validateResult(resultsFile, numbers, size) }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,22 @@ class SubmitRestProtocolSuite extends FunSuite {

test("validate") {
val request = new DummyRequest
intercept[AssertionError] { request.validate() } // missing everything
intercept[SubmitRestValidationException] { request.validate() } // missing everything
request.setSparkVersion("1.4.8")
intercept[AssertionError] { request.validate() } // missing name and age
intercept[SubmitRestValidationException] { request.validate() } // missing name and age
request.setName("something")
intercept[AssertionError] { request.validate() } // missing only age
intercept[SubmitRestValidationException] { request.validate() } // missing only age
request.setAge("2")
intercept[AssertionError] { request.validate() } // age too low
intercept[SubmitRestValidationException] { request.validate() } // age too low
request.setAge("10")
request.validate() // everything is set
request.setSparkVersion(null)
intercept[AssertionError] { request.validate() } // missing only Spark version
intercept[SubmitRestValidationException] { request.validate() } // missing only Spark version
request.setSparkVersion("1.2.3")
request.setName(null)
intercept[AssertionError] { request.validate() } // missing only name
intercept[SubmitRestValidationException] { request.validate() } // missing only name
request.setMessage("not-setting-name")
intercept[AssertionError] { request.validate() } // still missing name
intercept[SubmitRestValidationException] { request.validate() } // still missing name
}

test("request to and from JSON") {
Expand Down Expand Up @@ -119,7 +119,7 @@ class SubmitRestProtocolSuite extends FunSuite {

test("SubmitDriverRequest") {
val message = new SubmitDriverRequest
intercept[AssertionError] { message.validate() }
intercept[SubmitRestValidationException] { message.validate() }
intercept[IllegalArgumentException] { message.setDriverCores("one hundred feet") }
intercept[IllegalArgumentException] { message.setSuperviseDriver("nope, never") }
intercept[IllegalArgumentException] { message.setTotalExecutorCores("two men") }
Expand Down Expand Up @@ -181,7 +181,7 @@ class SubmitRestProtocolSuite extends FunSuite {

test("SubmitDriverResponse") {
val message = new SubmitDriverResponse
intercept[AssertionError] { message.validate() }
intercept[SubmitRestValidationException] { message.validate() }
intercept[IllegalArgumentException] { message.setSuccess("maybe not") }
message.setSparkVersion("1.2.3")
message.setDriverId("driver_123")
Expand All @@ -199,7 +199,7 @@ class SubmitRestProtocolSuite extends FunSuite {

test("KillDriverRequest") {
val message = new KillDriverRequest
intercept[AssertionError] { message.validate() }
intercept[SubmitRestValidationException] { message.validate() }
message.setSparkVersion("1.2.3")
message.setDriverId("driver_123")
message.validate()
Expand All @@ -214,7 +214,7 @@ class SubmitRestProtocolSuite extends FunSuite {

test("KillDriverResponse") {
val message = new KillDriverResponse
intercept[AssertionError] { message.validate() }
intercept[SubmitRestValidationException] { message.validate() }
intercept[IllegalArgumentException] { message.setSuccess("maybe not") }
message.setSparkVersion("1.2.3")
message.setDriverId("driver_123")
Expand All @@ -232,7 +232,7 @@ class SubmitRestProtocolSuite extends FunSuite {

test("DriverStatusRequest") {
val message = new DriverStatusRequest
intercept[AssertionError] { message.validate() }
intercept[SubmitRestValidationException] { message.validate() }
message.setSparkVersion("1.2.3")
message.setDriverId("driver_123")
message.validate()
Expand All @@ -247,7 +247,7 @@ class SubmitRestProtocolSuite extends FunSuite {

test("DriverStatusResponse") {
val message = new DriverStatusResponse
intercept[AssertionError] { message.validate() }
intercept[SubmitRestValidationException] { message.validate() }
intercept[IllegalArgumentException] { message.setSuccess("maybe") }
message.setSparkVersion("1.2.3")
message.setDriverId("driver_123")
Expand All @@ -269,7 +269,7 @@ class SubmitRestProtocolSuite extends FunSuite {

test("ErrorResponse") {
val message = new ErrorResponse
intercept[AssertionError] { message.validate() }
intercept[SubmitRestValidationException] { message.validate() }
message.setSparkVersion("1.2.3")
message.setMessage("Field not found in submit request: X")
message.validate()
Expand Down Expand Up @@ -412,10 +412,10 @@ private class DummyRequest extends SubmitRestProtocolRequest {
def setAge(s: String): this.type = setNumericField(age, s)
def setName(s: String): this.type = setField(name, s)

override def validate(): Unit = {
super.validate()
protected override def doValidate(): Unit = {
super.doValidate()
assertFieldIsSet(name)
assertFieldIsSet(age)
assert(age.getValue > 5, "Not old enough!")
assert(age.getValue.get > 5, "Not old enough!")
}
}

0 comments on commit 9581df7

Please sign in to comment.