68
68
import java .util .stream .Collectors ;
69
69
import java .util .stream .StreamSupport ;
70
70
71
+ import static org .hamcrest .Matchers .containsString ;
72
+ import static org .hamcrest .Matchers .either ;
71
73
import static org .hamcrest .Matchers .empty ;
72
74
import static org .hamcrest .Matchers .equalTo ;
73
75
import static org .hamcrest .Matchers .hasSize ;
@@ -146,7 +148,7 @@ public void testBanOnlyNodesWithOutstandingChildTasks() throws Exception {
146
148
beforeExecuteLatches .get (req ).countDown ();
147
149
}
148
150
cancelFuture .actionGet ();
149
- mainTaskFuture . actionGet ( );
151
+ waitForMainTask ( mainTaskFuture );
150
152
assertBusy (() -> {
151
153
for (DiscoveryNode node : nodes ) {
152
154
TaskManager taskManager = internalCluster ().getInstance (TransportService .class , node .getName ()).getTaskManager ();
@@ -177,7 +179,7 @@ public void testCancelTaskMultipleTimes() throws Exception {
177
179
}
178
180
assertThat (cancelFuture .actionGet ().getTaskFailures (), empty ());
179
181
assertThat (cancelFuture .actionGet ().getTaskFailures (), empty ());
180
- mainTaskFuture . actionGet ( );
182
+ waitForMainTask ( mainTaskFuture );
181
183
CancelTasksResponse cancelError = client ().admin ().cluster ().prepareCancelTasks ()
182
184
.setTaskId (taskId ).waitForCompletion (randomBoolean ()).get ();
183
185
assertThat (cancelError .getNodeFailures (), hasSize (1 ));
@@ -204,7 +206,7 @@ public void testDoNotWaitForCompletion() throws Exception {
204
206
for (ChildRequest r : childRequests ) {
205
207
beforeExecuteLatches .get (r ).countDown ();
206
208
}
207
- mainTaskFuture . actionGet ( );
209
+ waitForMainTask ( mainTaskFuture );
208
210
}
209
211
210
212
TaskId getMainTaskId () {
@@ -214,6 +216,17 @@ TaskId getMainTaskId() {
214
216
return listTasksResponse .getTasks ().get (0 ).getTaskId ();
215
217
}
216
218
219
+ void waitForMainTask (ActionFuture <MainResponse > mainTask ) {
220
+ try {
221
+ mainTask .actionGet ();
222
+ } catch (Exception e ) {
223
+ final Throwable cause = ExceptionsHelper .unwrap (e , TaskCancelledException .class );
224
+ assertThat (cause .getMessage (),
225
+ either (equalTo ("The parent task was cancelled, shouldn't start any child tasks" ))
226
+ .or (containsString ("Task cancelled before it started:" )));
227
+ }
228
+ }
229
+
217
230
public static class MainRequest extends ActionRequest {
218
231
final List <ChildRequest > childRequests ;
219
232
@@ -302,7 +315,7 @@ public Task createTask(long id, String type, String action, TaskId parentTaskId,
302
315
return new CancellableTask (id , type , action , getDescription (), parentTaskId , headers ) {
303
316
@ Override
304
317
public boolean shouldCancelChildrenOnCancellation () {
305
- return false ;
318
+ return shouldCancelChildrenOnCancellation ;
306
319
}
307
320
};
308
321
} else {
@@ -364,15 +377,15 @@ protected void doExecute(Task task, MainRequest request, ActionListener<MainResp
364
377
protected void startChildTask (TaskId parentTaskId , ChildRequest childRequest , ActionListener <ChildResponse > listener ) {
365
378
childRequest .setParentTask (parentTaskId );
366
379
final CountDownLatch completeLatch = completedLatches .get (childRequest );
380
+ LatchedActionListener <ChildResponse > latchedListener = new LatchedActionListener <>(listener , completeLatch );
367
381
transportService .getThreadPool ().generic ().execute (new AbstractRunnable () {
368
382
@ Override
369
383
public void onFailure (Exception e ) {
370
- throw new AssertionError (e );
384
+ listener . onFailure (e );
371
385
}
372
386
373
387
@ Override
374
388
protected void doRun () {
375
- LatchedActionListener <ChildResponse > latchedListener = new LatchedActionListener <>(listener , completeLatch );
376
389
if (client .getLocalNodeId ().equals (childRequest .targetNode .getId ()) && randomBoolean ()) {
377
390
try {
378
391
client .executeLocally (TransportChildAction .ACTION , childRequest , latchedListener );
0 commit comments