99package org .elasticsearch .action .support .nodes ;
1010
1111import org .elasticsearch .ElasticsearchException ;
12+ import org .elasticsearch .action .ActionListener ;
1213import org .elasticsearch .action .FailedNodeException ;
1314import org .elasticsearch .action .support .ActionFilters ;
1415import org .elasticsearch .action .support .PlainActionFuture ;
15- import org .elasticsearch .action .support .broadcast .node .TransportBroadcastByNodeActionTests ;
16+ import org .elasticsearch .action .support .RefCountingListener ;
17+ import org .elasticsearch .action .support .SubscribableListener ;
1618import org .elasticsearch .cluster .ClusterName ;
1719import org .elasticsearch .cluster .ClusterState ;
1820import org .elasticsearch .cluster .node .DiscoveryNode ;
5557import java .util .Set ;
5658import java .util .concurrent .Executor ;
5759import java .util .concurrent .TimeUnit ;
60+ import java .util .concurrent .atomic .AtomicInteger ;
61+ import java .util .function .Function ;
62+ import java .util .function .ObjLongConsumer ;
5863
5964import static java .util .Collections .emptyMap ;
6065import static org .elasticsearch .test .ClusterServiceUtils .createClusterService ;
@@ -118,7 +123,10 @@ public void testResponseAggregation() {
118123 final TestTransportNodesAction action = getTestTransportNodesAction ();
119124
120125 final PlainActionFuture <TestNodesResponse > listener = new PlainActionFuture <>();
121- action .execute (null , new TestNodesRequest (), listener );
126+ action .execute (null , new TestNodesRequest (), listener .delegateFailure ((l , response ) -> {
127+ assertTrue (response .getNodes ().stream ().allMatch (TestNodeResponse ::hasReferences ));
128+ l .onResponse (response );
129+ }));
122130 assertFalse (listener .isDone ());
123131
124132 final Set <String > failedNodeIds = new HashSet <>();
@@ -127,7 +135,9 @@ public void testResponseAggregation() {
127135 for (CapturingTransport .CapturedRequest capturedRequest : transport .getCapturedRequestsAndClear ()) {
128136 if (randomBoolean ()) {
129137 successfulNodes .add (capturedRequest .node ());
130- transport .handleResponse (capturedRequest .requestId (), new TestNodeResponse (capturedRequest .node ()));
138+ final var response = new TestNodeResponse (capturedRequest .node ());
139+ transport .handleResponse (capturedRequest .requestId (), response );
140+ assertFalse (response .hasReferences ()); // response is copied (via the wire protocol) so this instance is released
131141 } else {
132142 failedNodeIds .add (capturedRequest .node ().getId ());
133143 if (randomBoolean ()) {
@@ -138,7 +148,16 @@ public void testResponseAggregation() {
138148 }
139149 }
140150
141- TestNodesResponse response = listener .actionGet (10 , TimeUnit .SECONDS );
151+ final TestNodesResponse response = listener .actionGet (10 , TimeUnit .SECONDS );
152+
153+ final var allResponsesReleasedListener = new SubscribableListener <Void >();
154+ try (var listeners = new RefCountingListener (allResponsesReleasedListener )) {
155+ for (final var nodeResponse : response .getNodes ()) {
156+ nodeResponse .addCloseListener (listeners .acquire ());
157+ }
158+ }
159+ safeAwait (allResponsesReleasedListener );
160+ assertTrue (response .getNodes ().stream ().noneMatch (TestNodeResponse ::hasReferences ));
142161
143162 for (TestNodeResponse nodeResponse : response .getNodes ()) {
144163 assertThat (successfulNodes , Matchers .hasItem (nodeResponse .getNode ()));
@@ -164,7 +183,7 @@ public void testResponsesReleasedOnCancellation() {
164183 final CancellableTask cancellableTask = new CancellableTask (randomLong (), "transport" , "action" , "" , null , emptyMap ());
165184 final PlainActionFuture <TestNodesResponse > listener = new PlainActionFuture <>();
166185 action .execute (cancellableTask , new TestNodesRequest (), listener .delegateResponse ((l , e ) -> {
167- assert Thread . currentThread (). getName (). contains ( "[" + ThreadPool .Names .GENERIC + "]" );
186+ assert ThreadPool . assertCurrentThreadPool ( ThreadPool .Names .GENERIC );
168187 l .onFailure (e );
169188 }));
170189
@@ -173,13 +192,31 @@ public void testResponsesReleasedOnCancellation() {
173192 );
174193 Randomness .shuffle (capturedRequests );
175194
195+ final AtomicInteger liveResponseCount = new AtomicInteger ();
196+ final Function <DiscoveryNode , TestNodeResponse > responseCreator = node -> {
197+ liveResponseCount .incrementAndGet ();
198+ final var testNodeResponse = new TestNodeResponse (node );
199+ testNodeResponse .addCloseListener (ActionListener .running (liveResponseCount ::decrementAndGet ));
200+ return testNodeResponse ;
201+ };
202+
203+ final ObjLongConsumer <TestNodeResponse > responseSender = (response , requestId ) -> {
204+ try {
205+ // transport.handleResponse may de/serialize the response, releasing it early, so send the response straight to the handler
206+ transport .getTransportResponseHandler (requestId ).handleResponse (response );
207+ } finally {
208+ response .decRef ();
209+ }
210+ };
211+
176212 final ReachabilityChecker reachabilityChecker = new ReachabilityChecker ();
177213 final Runnable nextRequestProcessor = () -> {
178214 var capturedRequest = capturedRequests .remove (0 );
179215 if (randomBoolean ()) {
180- // transport.handleResponse may de/serialize the response, releasing it early, so send the response straight to the handler
181- transport .getTransportResponseHandler (capturedRequest .requestId ())
182- .handleResponse (reachabilityChecker .register (new TestNodeResponse (capturedRequest .node ())));
216+ responseSender .accept (
217+ reachabilityChecker .register (responseCreator .apply (capturedRequest .node ())),
218+ capturedRequest .requestId ()
219+ );
183220 } else {
184221 // handleRemoteError may de/serialize the exception, releasing it early, so just use handleLocalError
185222 transport .handleLocalError (
@@ -200,20 +237,23 @@ public void testResponsesReleasedOnCancellation() {
200237
201238 // responses captured before cancellation are now unreachable
202239 reachabilityChecker .ensureUnreachable ();
240+ assertEquals (0 , liveResponseCount .get ());
203241
204242 while (capturedRequests .size () > 0 ) {
205243 // a response sent after cancellation is dropped immediately
206244 assertFalse (listener .isDone ());
207245 nextRequestProcessor .run ();
208246 reachabilityChecker .ensureUnreachable ();
247+ assertEquals (0 , liveResponseCount .get ());
209248 }
210249
211250 expectThrows (TaskCancelledException .class , () -> listener .actionGet (10 , TimeUnit .SECONDS ));
251+ assertTrue (cancellableTask .isCancelled ()); // keep task alive
212252 }
213253
214254 @ BeforeClass
215255 public static void startThreadPool () {
216- THREAD_POOL = new TestThreadPool (TransportBroadcastByNodeActionTests .class .getSimpleName ());
256+ THREAD_POOL = new TestThreadPool (TransportNodesActionTests .class .getSimpleName ());
217257 }
218258
219259 @ AfterClass
@@ -268,11 +308,9 @@ public void tearDown() throws Exception {
268308
269309 public TestTransportNodesAction getTestTransportNodesAction () {
270310 return new TestTransportNodesAction (
271- THREAD_POOL ,
272311 clusterService ,
273312 transportService ,
274313 new ActionFilters (Collections .emptySet ()),
275- TestNodesRequest ::new ,
276314 TestNodeRequest ::new ,
277315 THREAD_POOL .executor (ThreadPool .Names .GENERIC )
278316 );
@@ -302,11 +340,9 @@ private static class TestTransportNodesAction extends TransportNodesAction<
302340 TestNodeResponse > {
303341
304342 TestTransportNodesAction (
305- ThreadPool threadPool ,
306343 ClusterService clusterService ,
307344 TransportService transportService ,
308345 ActionFilters actionFilters ,
309- Writeable .Reader <TestNodesRequest > request ,
310346 Writeable .Reader <TestNodeRequest > nodeRequest ,
311347 Executor nodeExecutor
312348 ) {
@@ -319,7 +355,7 @@ protected TestNodesResponse newResponse(
319355 List <TestNodeResponse > responses ,
320356 List <FailedNodeException > failures
321357 ) {
322- return new TestNodesResponse (clusterService .getClusterName (), request , responses , failures );
358+ return new TestNodesResponse (clusterService .getClusterName (), responses , failures );
323359 }
324360
325361 @ Override
@@ -350,7 +386,7 @@ private static class DataNodesOnlyTransportNodesAction extends TestTransportNode
350386 Writeable .Reader <TestNodeRequest > nodeRequest ,
351387 Executor nodeExecutor
352388 ) {
353- super (threadPool , clusterService , transportService , actionFilters , request , nodeRequest , nodeExecutor );
389+ super (clusterService , transportService , actionFilters , nodeRequest , nodeExecutor );
354390 }
355391
356392 @ Override
@@ -371,16 +407,8 @@ private static class TestNodesRequest extends BaseNodesRequest<TestNodesRequest>
371407
372408 private static class TestNodesResponse extends BaseNodesResponse <TestNodeResponse > {
373409
374- private final TestNodesRequest request ;
375-
376- TestNodesResponse (
377- ClusterName clusterName ,
378- TestNodesRequest request ,
379- List <TestNodeResponse > nodeResponses ,
380- List <FailedNodeException > failures
381- ) {
410+ TestNodesResponse (ClusterName clusterName , List <TestNodeResponse > nodeResponses , List <FailedNodeException > failures ) {
382411 super (clusterName , nodeResponses , failures );
383- this .request = request ;
384412 }
385413
386414 @ Override
@@ -425,6 +453,10 @@ public boolean hasReferences() {
425453 }
426454
427455 private static class TestNodeResponse extends BaseNodeResponse {
456+
457+ private final SubscribableListener <Void > onClose = new SubscribableListener <>();
458+ private final RefCounted refCounted = AbstractRefCounted .of (() -> onClose .onResponse (null ));
459+
428460 TestNodeResponse () {
429461 this (mock (DiscoveryNode .class ));
430462 }
@@ -436,6 +468,30 @@ private static class TestNodeResponse extends BaseNodeResponse {
436468 protected TestNodeResponse (StreamInput in ) throws IOException {
437469 super (in );
438470 }
471+
472+ @ Override
473+ public void incRef () {
474+ refCounted .incRef ();
475+ }
476+
477+ @ Override
478+ public boolean tryIncRef () {
479+ return refCounted .tryIncRef ();
480+ }
481+
482+ @ Override
483+ public boolean decRef () {
484+ return refCounted .decRef ();
485+ }
486+
487+ @ Override
488+ public boolean hasReferences () {
489+ return refCounted .hasReferences ();
490+ }
491+
492+ void addCloseListener (ActionListener <Void > listener ) {
493+ onClose .addListener (listener );
494+ }
439495 }
440496
441497}
0 commit comments