@@ -22,11 +22,11 @@ import java.util.concurrent.TimeoutException
22
22
23
23
import scala .concurrent .duration .FiniteDuration
24
24
import scala .concurrent .duration ._
25
- import scala .concurrent .{Await , Future }
25
+ import scala .concurrent .{Awaitable , Await , Future }
26
26
import scala .language .postfixOps
27
27
28
28
import org .apache .spark .{SecurityManager , SparkConf }
29
- import org .apache .spark .util .{RpcUtils , Utils }
29
+ import org .apache .spark .util .{ThreadUtils , RpcUtils , Utils }
30
30
31
31
32
32
/**
@@ -187,6 +187,13 @@ private[spark] object RpcAddress {
187
187
}
188
188
189
189
190
+ /**
191
+ * An exception thrown if RpcTimeout modifies a [[TimeoutException ]].
192
+ */
193
+ private [rpc] class RpcTimeoutException (message : String )
194
+ extends TimeoutException (message)
195
+
196
+
190
197
/**
191
198
* Associates a timeout with a description so that a when a TimeoutException occurs, additional
192
199
* context about the timeout can be amended to the exception message.
@@ -202,17 +209,44 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
202
209
def message : String = description
203
210
204
211
/** Amends the standard message of TimeoutException to include the description */
205
- def amend (te : TimeoutException ): TimeoutException = {
206
- new TimeoutException (te.getMessage() + " " + description)
212
+ def createRpcTimeoutException (te : TimeoutException ): RpcTimeoutException = {
213
+ new RpcTimeoutException (te.getMessage() + " " + description)
214
+ }
215
+
216
+ /**
217
+ * Add a callback to the given Future so that if it completes as failed with a TimeoutException
218
+ * then the timeout description is added to the message
219
+ */
220
+ def addMessageIfTimeout [T ](future : Future [T ]): Future [T ] = {
221
+ future.recover {
222
+ // Add a warning message if Future is passed to addMessageIfTimeoutTest more than once
223
+ case rte : RpcTimeoutException => throw new RpcTimeoutException (rte.getMessage() +
224
+ " (Future has multiple calls to RpcTimeout.addMessageIfTimeoutTest)" )
225
+ // Any other TimeoutException get converted to a RpcTimeoutException with modified message
226
+ case te : TimeoutException => throw createRpcTimeoutException(te)
227
+ }(ThreadUtils .sameThread)
228
+ }
229
+
230
+ /** Applies the duration to create future before calling addMessageIfTimeout*/
231
+ def addMessageIfTimeout [T ](f : FiniteDuration => Future [T ]): Future [T ] = {
232
+ addMessageIfTimeout(f(duration))
207
233
}
208
234
209
- /** Wait on a future result to catch and amend a TimeoutException */
210
- def awaitResult [T ](future : Future [T ]): T = {
235
+ /**
236
+ * Waits for a completed result to catch and amend a TimeoutException message
237
+ * @param awaitable the `Awaitable` to be awaited
238
+ * @throws RpcTimeoutException if after waiting for the specified time `awaitable`
239
+ * is still not ready
240
+ */
241
+ def awaitResult [T ](awaitable : Awaitable [T ]): T = {
211
242
try {
212
- Await .result(future , duration)
243
+ Await .result(awaitable , duration)
213
244
}
214
245
catch {
215
- case te : TimeoutException => throw amend(te)
246
+ // The exception has already been converted to a RpcTimeoutException so just raise it
247
+ case rte : RpcTimeoutException => throw rte
248
+ // Any other TimeoutException get converted to a RpcTimeoutException with modified message
249
+ case te : TimeoutException => throw createRpcTimeoutException(te)
216
250
}
217
251
}
218
252
}
0 commit comments