Skip to content

Commit 2f94095

Browse files
committed
[SPARK-6980] Added addMessageIfTimeout for when a Future is completed with TimeoutException
1 parent 235919b commit 2f94095

File tree

2 files changed

+60
-22
lines changed

2 files changed

+60
-22
lines changed

core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ import java.util.concurrent.TimeoutException
2222

2323
import scala.concurrent.duration.FiniteDuration
2424
import scala.concurrent.duration._
25-
import scala.concurrent.{Await, Future}
25+
import scala.concurrent.{Awaitable, Await, Future}
2626
import scala.language.postfixOps
2727

2828
import org.apache.spark.{SecurityManager, SparkConf}
29-
import org.apache.spark.util.{RpcUtils, Utils}
29+
import org.apache.spark.util.{ThreadUtils, RpcUtils, Utils}
3030

3131

3232
/**
@@ -187,6 +187,13 @@ private[spark] object RpcAddress {
187187
}
188188

189189

190+
/**
191+
* An exception thrown if RpcTimeout modifies a [[TimeoutException]].
192+
*/
193+
private[rpc] class RpcTimeoutException(message: String)
194+
extends TimeoutException(message)
195+
196+
190197
/**
191198
* Associates a timeout with a description so that a when a TimeoutException occurs, additional
192199
* context about the timeout can be amended to the exception message.
@@ -202,17 +209,44 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
202209
def message: String = description
203210

204211
/** Amends the standard message of TimeoutException to include the description */
205-
def amend(te: TimeoutException): TimeoutException = {
206-
new TimeoutException(te.getMessage() + " " + description)
212+
def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = {
213+
new RpcTimeoutException(te.getMessage() + " " + description)
214+
}
215+
216+
/**
217+
* Add a callback to the given Future so that if it completes as failed with a TimeoutException
218+
* then the timeout description is added to the message
219+
*/
220+
def addMessageIfTimeout[T](future: Future[T]): Future[T] = {
221+
future.recover {
222+
// Add a warning message if Future is passed to addMessageIfTimeoutTest more than once
223+
case rte: RpcTimeoutException => throw new RpcTimeoutException(rte.getMessage() +
224+
" (Future has multiple calls to RpcTimeout.addMessageIfTimeoutTest)")
225+
// Any other TimeoutException get converted to a RpcTimeoutException with modified message
226+
case te: TimeoutException => throw createRpcTimeoutException(te)
227+
}(ThreadUtils.sameThread)
228+
}
229+
230+
/** Applies the duration to create future before calling addMessageIfTimeout*/
231+
def addMessageIfTimeout[T](f: FiniteDuration => Future[T]): Future[T] = {
232+
addMessageIfTimeout(f(duration))
207233
}
208234

209-
/** Wait on a future result to catch and amend a TimeoutException */
210-
def awaitResult[T](future: Future[T]): T = {
235+
/**
236+
* Waits for a completed result to catch and amend a TimeoutException message
237+
* @param awaitable the `Awaitable` to be awaited
238+
* @throws RpcTimeoutException if after waiting for the specified time `awaitable`
239+
* is still not ready
240+
*/
241+
def awaitResult[T](awaitable: Awaitable[T]): T = {
211242
try {
212-
Await.result(future, duration)
243+
Await.result(awaitable, duration)
213244
}
214245
catch {
215-
case te: TimeoutException => throw amend(te)
246+
// The exception has already been converted to a RpcTimeoutException so just raise it
247+
case rte: RpcTimeoutException => throw rte
248+
// Any other TimeoutException get converted to a RpcTimeoutException with modified message
249+
case te: TimeoutException => throw createRpcTimeoutException(te)
216250
}
217251
}
218252
}

core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,10 @@ private[spark] class AkkaRpcEnv private[akka] (
213213

214214
override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
215215
import actorSystem.dispatcher
216-
actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration).
217-
map(new AkkaRpcEndpointRef(defaultAddress, _, conf))
216+
defaultLookupTimeout.addMessageIfTimeout(
217+
actorSystem.actorSelection(uri).resolveOne(_).
218+
map(new AkkaRpcEndpointRef(defaultAddress, _, conf))
219+
)
218220
}
219221

220222
override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = {
@@ -295,18 +297,20 @@ private[akka] class AkkaRpcEndpointRef(
295297
}
296298

297299
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
298-
actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap {
299-
// The function will run in the calling thread, so it should be short and never block.
300-
case msg @ AkkaMessage(message, reply) =>
301-
if (reply) {
302-
logError(s"Receive $msg but the sender cannot reply")
303-
Future.failed(new SparkException(s"Receive $msg but the sender cannot reply"))
304-
} else {
305-
Future.successful(message)
306-
}
307-
case AkkaFailure(e) =>
308-
Future.failed(e)
309-
}(ThreadUtils.sameThread).mapTo[T]
300+
timeout.addMessageIfTimeout(
301+
actorRef.ask(AkkaMessage(message, true))(_).flatMap {
302+
// The function will run in the calling thread, so it should be short and never block.
303+
case msg @ AkkaMessage(message, reply) =>
304+
if (reply) {
305+
logError(s"Receive $msg but the sender cannot reply")
306+
Future.failed(new SparkException(s"Receive $msg but the sender cannot reply"))
307+
} else {
308+
Future.successful(message)
309+
}
310+
case AkkaFailure(e) =>
311+
Future.failed(e)
312+
}(ThreadUtils.sameThread).mapTo[T]
313+
)
310314
}
311315

312316
override def toString: String = s"${getClass.getSimpleName}($actorRef)"

0 commit comments

Comments
 (0)