Skip to content

Commit 1607a5f

Browse files
committed
[SPARK-6980] Changed addMessageIfTimeout to PartialFunction, cleanup from PR comments
1 parent 2f94095 commit 1607a5f

File tree

2 files changed

+37
-50
lines changed

2 files changed

+37
-50
lines changed

core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import scala.concurrent.{Awaitable, Await, Future}
2626
import scala.language.postfixOps
2727

2828
import org.apache.spark.{SecurityManager, SparkConf}
29-
import org.apache.spark.util.{ThreadUtils, RpcUtils, Utils}
29+
import org.apache.spark.util.{RpcUtils, Utils}
3030

3131

3232
/**
@@ -190,8 +190,8 @@ private[spark] object RpcAddress {
190190
/**
191191
* An exception thrown if RpcTimeout modifies a [[TimeoutException]].
192192
*/
193-
private[rpc] class RpcTimeoutException(message: String)
194-
extends TimeoutException(message)
193+
private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException)
194+
extends TimeoutException(message) { initCause(cause) }
195195

196196

197197
/**
@@ -209,27 +209,23 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
209209
def message: String = description
210210

211211
/** Amends the standard message of TimeoutException to include the description */
212-
def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = {
213-
new RpcTimeoutException(te.getMessage() + " " + description)
212+
private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = {
213+
new RpcTimeoutException(te.getMessage() + " " + description, te)
214214
}
215215

216216
/**
217-
* Add a callback to the given Future so that if it completes as failed with a TimeoutException
218-
* then the timeout description is added to the message
217+
* PartialFunction to match a TimeoutException and add the timeout description to the message
218+
*
219+
* @note This can be used in the recover callback of a Future to add to a TimeoutException
220+
* Example:
221+
* val timeout = new RpcTimeout(5 millis, "short timeout")
222+
* Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout)
219223
*/
220-
def addMessageIfTimeout[T](future: Future[T]): Future[T] = {
221-
future.recover {
222-
// Add a warning message if Future is passed to addMessageIfTimeoutTest more than once
223-
case rte: RpcTimeoutException => throw new RpcTimeoutException(rte.getMessage() +
224-
" (Future has multiple calls to RpcTimeout.addMessageIfTimeoutTest)")
225-
// Any other TimeoutException get converted to a RpcTimeoutException with modified message
226-
case te: TimeoutException => throw createRpcTimeoutException(te)
227-
}(ThreadUtils.sameThread)
228-
}
229-
230-
/** Applies the duration to create future before calling addMessageIfTimeout*/
231-
def addMessageIfTimeout[T](f: FiniteDuration => Future[T]): Future[T] = {
232-
addMessageIfTimeout(f(duration))
224+
def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = {
225+
// The exception has already been converted to a RpcTimeoutException so just raise it
226+
case rte: RpcTimeoutException => throw rte
227+
// Any other TimeoutException get converted to a RpcTimeoutException with modified message
228+
case te: TimeoutException => throw createRpcTimeoutException(te)
233229
}
234230

235231
/**
@@ -241,13 +237,7 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
241237
def awaitResult[T](awaitable: Awaitable[T]): T = {
242238
try {
243239
Await.result(awaitable, duration)
244-
}
245-
catch {
246-
// The exception has already been converted to a RpcTimeoutException so just raise it
247-
case rte: RpcTimeoutException => throw rte
248-
// Any other TimeoutException get converted to a RpcTimeoutException with modified message
249-
case te: TimeoutException => throw createRpcTimeoutException(te)
250-
}
240+
} catch addMessageIfTimeout
251241
}
252242
}
253243

@@ -299,13 +289,10 @@ object RpcTimeout {
299289

300290
// Find the first set property or use the default value with the first property
301291
val itr = timeoutPropList.iterator
302-
var foundProp = None: Option[(String, String)]
292+
var foundProp: Option[(String, String)] = None
303293
while (itr.hasNext && foundProp.isEmpty){
304294
val propKey = itr.next()
305-
conf.getOption(propKey) match {
306-
case Some(prop) => foundProp = Some(propKey,prop)
307-
case None =>
308-
}
295+
conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) }
309296
}
310297
val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue)
311298
val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds }

core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,11 @@ private[spark] class AkkaRpcEnv private[akka] (
213213

214214
override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
215215
import actorSystem.dispatcher
216-
defaultLookupTimeout.addMessageIfTimeout(
217-
actorSystem.actorSelection(uri).resolveOne(_).
218-
map(new AkkaRpcEndpointRef(defaultAddress, _, conf))
219-
)
216+
actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration).
217+
map(new AkkaRpcEndpointRef(defaultAddress, _, conf)).
218+
// this is just in case there is a timeout from creating the future in resolveOne, we want the
219+
// exception to indicate the conf that determines the timeout
220+
recover(defaultLookupTimeout.addMessageIfTimeout)
220221
}
221222

222223
override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = {
@@ -297,20 +298,19 @@ private[akka] class AkkaRpcEndpointRef(
297298
}
298299

299300
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
300-
timeout.addMessageIfTimeout(
301-
actorRef.ask(AkkaMessage(message, true))(_).flatMap {
302-
// The function will run in the calling thread, so it should be short and never block.
303-
case msg @ AkkaMessage(message, reply) =>
304-
if (reply) {
305-
logError(s"Receive $msg but the sender cannot reply")
306-
Future.failed(new SparkException(s"Receive $msg but the sender cannot reply"))
307-
} else {
308-
Future.successful(message)
309-
}
310-
case AkkaFailure(e) =>
311-
Future.failed(e)
312-
}(ThreadUtils.sameThread).mapTo[T]
313-
)
301+
actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap {
302+
// The function will run in the calling thread, so it should be short and never block.
303+
case msg @ AkkaMessage(message, reply) =>
304+
if (reply) {
305+
logError(s"Receive $msg but the sender cannot reply")
306+
Future.failed(new SparkException(s"Receive $msg but the sender cannot reply"))
307+
} else {
308+
Future.successful(message)
309+
}
310+
case AkkaFailure(e) =>
311+
Future.failed(e)
312+
}(ThreadUtils.sameThread).mapTo[T].
313+
recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
314314
}
315315

316316
override def toString: String = s"${getClass.getSimpleName}($actorRef)"

0 commit comments

Comments
 (0)