@@ -26,7 +26,7 @@ import scala.concurrent.{Awaitable, Await, Future}
26
26
import scala .language .postfixOps
27
27
28
28
import org .apache .spark .{SecurityManager , SparkConf }
29
- import org .apache .spark .util .{ThreadUtils , RpcUtils , Utils }
29
+ import org .apache .spark .util .{RpcUtils , Utils }
30
30
31
31
32
32
/**
@@ -190,8 +190,8 @@ private[spark] object RpcAddress {
190
190
/**
191
191
* An exception thrown if RpcTimeout modifies a [[TimeoutException ]].
192
192
*/
193
- private [rpc] class RpcTimeoutException (message : String )
194
- extends TimeoutException (message)
193
+ private [rpc] class RpcTimeoutException (message : String , cause : TimeoutException )
194
+ extends TimeoutException (message) { initCause(cause) }
195
195
196
196
197
197
/**
@@ -209,27 +209,23 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
209
209
def message : String = description
210
210
211
211
/** Amends the standard message of TimeoutException to include the description */
212
- def createRpcTimeoutException (te : TimeoutException ): RpcTimeoutException = {
213
- new RpcTimeoutException (te.getMessage() + " " + description)
212
+ private def createRpcTimeoutException (te : TimeoutException ): RpcTimeoutException = {
213
+ new RpcTimeoutException (te.getMessage() + " " + description, te )
214
214
}
215
215
216
216
/**
217
- * Add a callback to the given Future so that if it completes as failed with a TimeoutException
218
- * then the timeout description is added to the message
217
+ * PartialFunction to match a TimeoutException and add the timeout description to the message
218
+ *
219
+ * @note This can be used in the recover callback of a Future to add to a TimeoutException
220
+ * Example:
221
+ * val timeout = new RpcTimeout(5 millis, "short timeout")
222
+ * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout)
219
223
*/
220
- def addMessageIfTimeout [T ](future : Future [T ]): Future [T ] = {
221
- future.recover {
222
- // Add a warning message if Future is passed to addMessageIfTimeoutTest more than once
223
- case rte : RpcTimeoutException => throw new RpcTimeoutException (rte.getMessage() +
224
- " (Future has multiple calls to RpcTimeout.addMessageIfTimeoutTest)" )
225
- // Any other TimeoutException get converted to a RpcTimeoutException with modified message
226
- case te : TimeoutException => throw createRpcTimeoutException(te)
227
- }(ThreadUtils .sameThread)
228
- }
229
-
230
- /** Applies the duration to create future before calling addMessageIfTimeout*/
231
- def addMessageIfTimeout [T ](f : FiniteDuration => Future [T ]): Future [T ] = {
232
- addMessageIfTimeout(f(duration))
224
+ def addMessageIfTimeout [T ]: PartialFunction [Throwable , T ] = {
225
+ // The exception has already been converted to a RpcTimeoutException so just raise it
226
+ case rte : RpcTimeoutException => throw rte
227
+ // Any other TimeoutException get converted to a RpcTimeoutException with modified message
228
+ case te : TimeoutException => throw createRpcTimeoutException(te)
233
229
}
234
230
235
231
/**
@@ -241,13 +237,7 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
241
237
def awaitResult [T ](awaitable : Awaitable [T ]): T = {
242
238
try {
243
239
Await .result(awaitable, duration)
244
- }
245
- catch {
246
- // The exception has already been converted to a RpcTimeoutException so just raise it
247
- case rte : RpcTimeoutException => throw rte
248
- // Any other TimeoutException get converted to a RpcTimeoutException with modified message
249
- case te : TimeoutException => throw createRpcTimeoutException(te)
250
- }
240
+ } catch addMessageIfTimeout
251
241
}
252
242
}
253
243
@@ -299,13 +289,10 @@ object RpcTimeout {
299
289
300
290
// Find the first set property or use the default value with the first property
301
291
val itr = timeoutPropList.iterator
302
- var foundProp = None : Option [(String , String )]
292
+ var foundProp : Option [(String , String )] = None
303
293
while (itr.hasNext && foundProp.isEmpty){
304
294
val propKey = itr.next()
305
- conf.getOption(propKey) match {
306
- case Some (prop) => foundProp = Some (propKey,prop)
307
- case None =>
308
- }
295
+ conf.getOption(propKey).foreach { prop => foundProp = Some (propKey, prop) }
309
296
}
310
297
val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue)
311
298
val timeout = { Utils .timeStringAsSeconds(finalProp._2) seconds }
0 commit comments