Skip to content

Commit 6c13a81

Browse files
authored
Refcount responses in TransportNodesAction (#103254)
Today we `decRef()` the per-node responses just after adding them to the `responses` collection, but in fact we should keep them alive until we've constructed the final response. This commit does that.
1 parent 922e790 commit 6c13a81

File tree

4 files changed

+160
-29
lines changed

4 files changed

+160
-29
lines changed

libs/core/src/main/java/org/elasticsearch/core/Releasables.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.io.IOException;
1212
import java.io.UncheckedIOException;
1313
import java.util.Arrays;
14+
import java.util.Iterator;
1415
import java.util.concurrent.atomic.AtomicReference;
1516

1617
/** Utility methods to work with {@link Releasable}s. */
@@ -103,6 +104,24 @@ public String toString() {
103104
};
104105
}
105106

107+
/**
108+
* Similar to {@link #wrap(Iterable)} except that it accepts an {@link Iterator} of releasables. The resulting resource must therefore
109+
* only be released once.
* <p>
* The iterator is not copied: every call to {@code iterator()} on the adapting {@link Iterable} returns the very instance passed in
* here, so a second release would resume from wherever the first release left that iterator rather than starting over.
*
* @param releasables the iterator supplying the resources to release; it is held by reference and consumed on release
* @return a {@link Releasable} built by delegating to {@link #wrap(Iterable)} and then passing the result through {@code assertOnce}
110+
*/
111+
public static Releasable wrap(final Iterator<Releasable> releasables) {
// Adapt the iterator to the Iterable-based overload. NOTE(review): assertOnce presumably enforces the single-release contract
// described above -- confirm against its definition elsewhere in Releasables.
112+
return assertOnce(wrap(new Iterable<>() {
113+
@Override
114+
public Iterator<Releasable> iterator() {
// Always the caller's instance, never a fresh iterator.
115+
return releasables;
116+
}
117+
118+
@Override
119+
public String toString() {
// Delegate so that the wrapper's description is derived from the iterator's own description.
120+
return releasables.toString();
121+
}
122+
}));
123+
}
124+
106125
/** @see #wrap(Iterable) */
107126
public static Releasable wrap(final Releasable... releasables) {
108127
return new Releasable() {

libs/core/src/test/java/org/elasticsearch/core/ReleasablesTests.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,5 +107,27 @@ public String toString() {
107107
assertEquals("wrapped[list]", wrapIterable.toString());
108108
wrapIterable.close();
109109
assertEquals(5, count.get());
110+
111+
final var wrapIterator = Releasables.wrap(new Iterator<>() {
112+
final Iterator<Releasable> innerIterator = List.of(releasable, releasable, releasable).iterator();
113+
114+
@Override
115+
public boolean hasNext() {
116+
return innerIterator.hasNext();
117+
}
118+
119+
@Override
120+
public Releasable next() {
121+
return innerIterator.next();
122+
}
123+
124+
@Override
125+
public String toString() {
126+
return "iterator";
127+
}
128+
});
129+
assertEquals("wrapped[iterator]", wrapIterator.toString());
130+
wrapIterator.close();
131+
assertEquals(8, count.get());
110132
}
111133
}

server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
import org.elasticsearch.common.io.stream.Writeable;
2626
import org.elasticsearch.common.util.concurrent.EsExecutors;
2727
import org.elasticsearch.core.CheckedConsumer;
28+
import org.elasticsearch.core.Releasables;
29+
import org.elasticsearch.tasks.CancellableTask;
2830
import org.elasticsearch.tasks.Task;
2931
import org.elasticsearch.transport.TransportChannel;
3032
import org.elasticsearch.transport.TransportRequest;
@@ -96,6 +98,23 @@ protected void doExecute(Task task, NodesRequest request, ActionListener<NodesRe
9698

9799
final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout());
98100

101+
// Initializer block: registers the cancellation hook as soon as the enclosing object is constructed.
{
102+
addReleaseOnCancellationListener();
103+
}
104+
/**
 * If {@code task} is a {@link CancellableTask}, registers a listener on it which drains the {@code responses} collection and
 * releases (via {@code decRef}) every response that was drained. Does nothing for other kinds of task.
 * NOTE(review): the listener is presumably invoked when the task is cancelled -- confirm against CancellableTask#addListener.
 */
105+
private void addReleaseOnCancellationListener() {
106+
if (task instanceof CancellableTask cancellableTask) {
107+
cancellableTask.addListener(() -> {
108+
final List<NodeResponse> drainedResponses;
// Snapshot and clear under the same monitor that onItemResponse uses, so each response ends up either in the
// snapshot or rejected by onItemResponse's own cancellation check.
109+
synchronized (responses) {
110+
drainedResponses = List.copyOf(responses);
111+
responses.clear();
112+
}
// Release the drained responses outside the synchronized block.
113+
Releasables.wrap(Iterators.map(drainedResponses.iterator(), r -> r::decRef)).close();
114+
});
115+
}
116+
}
117+
99118
@Override
100119
protected void sendItemRequest(DiscoveryNode discoveryNode, ActionListener<NodeResponse> listener) {
101120
final var nodeRequest = newNodeRequest(request);
@@ -118,9 +137,14 @@ protected void sendItemRequest(DiscoveryNode discoveryNode, ActionListener<NodeR
118137

119138
/**
 * Records a successful per-node response. A reference is taken on the response before it is stored; if the task is a cancelled
 * {@link CancellableTask} the response is not stored and that reference is dropped again straight away.
 * (In this diff view the line marked {@code 122-} is the removed pre-change line; the lines marked {@code +} are its replacement.)
 *
 * @param discoveryNode the node that produced the response (not used in this method body)
 * @param nodeResponse  the response to retain and collect
 */
@Override
120139
protected void onItemResponse(DiscoveryNode discoveryNode, NodeResponse nodeResponse) {
// Take our own reference up front so the response stays alive while it sits in the responses collection.
140+
nodeResponse.mustIncRef();
121141
synchronized (responses) {
122-
responses.add(nodeResponse);
// Only collect the response when the task is not a cancelled CancellableTask; the check happens under the same monitor
// that the cancellation listener uses to drain the collection.
142+
if ((task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) == false) {
143+
responses.add(nodeResponse);
// The reference taken above now belongs to the collection; skip the decRef below.
144+
return;
145+
}
123146
}
// Reached only when the response was not added: give back the reference taken at the top of the method.
147+
nodeResponse.decRef();
124148
}
125149

126150
@Override
@@ -134,7 +158,11 @@ protected void onItemFailure(DiscoveryNode discoveryNode, Exception e) {
134158
/**
 * Returns the completion step: a consumer which builds the final response via {@code newResponseAsync} and, once that call has
 * returned or thrown, releases (via {@code decRef}) each of the collected per-node responses.
 * (In this diff view the line marked {@code 137-} is the removed pre-change line; the lines marked {@code +} are its replacement.)
 *
 * @return a consumer that accepts the listener for the final {@code NodesResponse}
 */
@Override
135159
protected CheckedConsumer<ActionListener<NodesResponse>, Exception> onCompletion() {
136160
// ref releases all happen-before here so no need to be synchronized
137-
return l -> newResponseAsync(task, request, responses, exceptions, l);
161+
return l -> {
// try-with-resources: the wrapped releasable is closed when the block exits, whether normally or exceptionally, which
// walks the responses and calls decRef on each one.
162+
try (var ignored = Releasables.wrap(Iterators.map(responses.iterator(), r -> r::decRef))) {
163+
newResponseAsync(task, request, responses, exceptions, l);
164+
}
165+
};
138166
}
139167

140168
@Override
@@ -154,9 +182,11 @@ private Writeable.Reader<NodeResponse> nodeResponseReader(DiscoveryNode discover
154182
}
155183

156184
/**
157-
* Create a new {@link NodesResponse} (multi-node response).
185+
* Create a new {@link NodesResponse}. This method is executed on {@link #finalExecutor}.
158186
*
159-
* @param request The associated request.
187+
* @param request The request whose response we are constructing. {@link TransportNodesAction} may have already released all its
188+
* references to this object before calling this method, so it's up to individual implementations to retain their own
189+
* reference to the request if still needed here.
160190
* @param responses All successful node-level responses.
161191
* @param failures All node-level failures.
162192
* @return Never {@code null}.
@@ -166,7 +196,11 @@ private Writeable.Reader<NodeResponse> nodeResponseReader(DiscoveryNode discover
166196

167197
/**
168198
* Create a new {@link NodesResponse}, possibly asynchronously. The default implementation is synchronous and calls
169-
* {@link #newResponse(BaseNodesRequest, List, List)}
199+
* {@link #newResponse(BaseNodesRequest, List, List)}. This method is executed on {@link #finalExecutor}.
200+
*
201+
* @param request The request whose response we are constructing. {@link TransportNodesAction} may have already released all its
202+
* references to this object before calling this method, so it's up to individual implementations to retain their own
203+
* reference to the request if still needed here.
170204
*/
171205
protected void newResponseAsync(
172206
Task task,

server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java

Lines changed: 80 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
package org.elasticsearch.action.support.nodes;
1010

1111
import org.elasticsearch.ElasticsearchException;
12+
import org.elasticsearch.action.ActionListener;
1213
import org.elasticsearch.action.FailedNodeException;
1314
import org.elasticsearch.action.support.ActionFilters;
1415
import org.elasticsearch.action.support.PlainActionFuture;
15-
import org.elasticsearch.action.support.broadcast.node.TransportBroadcastByNodeActionTests;
16+
import org.elasticsearch.action.support.RefCountingListener;
17+
import org.elasticsearch.action.support.SubscribableListener;
1618
import org.elasticsearch.cluster.ClusterName;
1719
import org.elasticsearch.cluster.ClusterState;
1820
import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -55,6 +57,9 @@
5557
import java.util.Set;
5658
import java.util.concurrent.Executor;
5759
import java.util.concurrent.TimeUnit;
60+
import java.util.concurrent.atomic.AtomicInteger;
61+
import java.util.function.Function;
62+
import java.util.function.ObjLongConsumer;
5863

5964
import static java.util.Collections.emptyMap;
6065
import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;
@@ -118,7 +123,10 @@ public void testResponseAggregation() {
118123
final TestTransportNodesAction action = getTestTransportNodesAction();
119124

120125
final PlainActionFuture<TestNodesResponse> listener = new PlainActionFuture<>();
121-
action.execute(null, new TestNodesRequest(), listener);
126+
action.execute(null, new TestNodesRequest(), listener.delegateFailure((l, response) -> {
127+
assertTrue(response.getNodes().stream().allMatch(TestNodeResponse::hasReferences));
128+
l.onResponse(response);
129+
}));
122130
assertFalse(listener.isDone());
123131

124132
final Set<String> failedNodeIds = new HashSet<>();
@@ -127,7 +135,9 @@ public void testResponseAggregation() {
127135
for (CapturingTransport.CapturedRequest capturedRequest : transport.getCapturedRequestsAndClear()) {
128136
if (randomBoolean()) {
129137
successfulNodes.add(capturedRequest.node());
130-
transport.handleResponse(capturedRequest.requestId(), new TestNodeResponse(capturedRequest.node()));
138+
final var response = new TestNodeResponse(capturedRequest.node());
139+
transport.handleResponse(capturedRequest.requestId(), response);
140+
assertFalse(response.hasReferences()); // response is copied (via the wire protocol) so this instance is released
131141
} else {
132142
failedNodeIds.add(capturedRequest.node().getId());
133143
if (randomBoolean()) {
@@ -138,7 +148,16 @@ public void testResponseAggregation() {
138148
}
139149
}
140150

141-
TestNodesResponse response = listener.actionGet(10, TimeUnit.SECONDS);
151+
final TestNodesResponse response = listener.actionGet(10, TimeUnit.SECONDS);
152+
153+
final var allResponsesReleasedListener = new SubscribableListener<Void>();
154+
try (var listeners = new RefCountingListener(allResponsesReleasedListener)) {
155+
for (final var nodeResponse : response.getNodes()) {
156+
nodeResponse.addCloseListener(listeners.acquire());
157+
}
158+
}
159+
safeAwait(allResponsesReleasedListener);
160+
assertTrue(response.getNodes().stream().noneMatch(TestNodeResponse::hasReferences));
142161

143162
for (TestNodeResponse nodeResponse : response.getNodes()) {
144163
assertThat(successfulNodes, Matchers.hasItem(nodeResponse.getNode()));
@@ -164,7 +183,7 @@ public void testResponsesReleasedOnCancellation() {
164183
final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());
165184
final PlainActionFuture<TestNodesResponse> listener = new PlainActionFuture<>();
166185
action.execute(cancellableTask, new TestNodesRequest(), listener.delegateResponse((l, e) -> {
167-
assert Thread.currentThread().getName().contains("[" + ThreadPool.Names.GENERIC + "]");
186+
assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.GENERIC);
168187
l.onFailure(e);
169188
}));
170189

@@ -173,13 +192,31 @@ public void testResponsesReleasedOnCancellation() {
173192
);
174193
Randomness.shuffle(capturedRequests);
175194

195+
final AtomicInteger liveResponseCount = new AtomicInteger();
196+
final Function<DiscoveryNode, TestNodeResponse> responseCreator = node -> {
197+
liveResponseCount.incrementAndGet();
198+
final var testNodeResponse = new TestNodeResponse(node);
199+
testNodeResponse.addCloseListener(ActionListener.running(liveResponseCount::decrementAndGet));
200+
return testNodeResponse;
201+
};
202+
203+
final ObjLongConsumer<TestNodeResponse> responseSender = (response, requestId) -> {
204+
try {
205+
// transport.handleResponse may de/serialize the response, releasing it early, so send the response straight to the handler
206+
transport.getTransportResponseHandler(requestId).handleResponse(response);
207+
} finally {
208+
response.decRef();
209+
}
210+
};
211+
176212
final ReachabilityChecker reachabilityChecker = new ReachabilityChecker();
177213
final Runnable nextRequestProcessor = () -> {
178214
var capturedRequest = capturedRequests.remove(0);
179215
if (randomBoolean()) {
180-
// transport.handleResponse may de/serialize the response, releasing it early, so send the response straight to the handler
181-
transport.getTransportResponseHandler(capturedRequest.requestId())
182-
.handleResponse(reachabilityChecker.register(new TestNodeResponse(capturedRequest.node())));
216+
responseSender.accept(
217+
reachabilityChecker.register(responseCreator.apply(capturedRequest.node())),
218+
capturedRequest.requestId()
219+
);
183220
} else {
184221
// handleRemoteError may de/serialize the exception, releasing it early, so just use handleLocalError
185222
transport.handleLocalError(
@@ -200,20 +237,23 @@ public void testResponsesReleasedOnCancellation() {
200237

201238
// responses captured before cancellation are now unreachable
202239
reachabilityChecker.ensureUnreachable();
240+
assertEquals(0, liveResponseCount.get());
203241

204242
while (capturedRequests.size() > 0) {
205243
// a response sent after cancellation is dropped immediately
206244
assertFalse(listener.isDone());
207245
nextRequestProcessor.run();
208246
reachabilityChecker.ensureUnreachable();
247+
assertEquals(0, liveResponseCount.get());
209248
}
210249

211250
expectThrows(TaskCancelledException.class, () -> listener.actionGet(10, TimeUnit.SECONDS));
251+
assertTrue(cancellableTask.isCancelled()); // keep task alive
212252
}
213253

214254
/**
 * Creates the shared {@code THREAD_POOL} for this test class, named after the test class itself.
 * (In this diff view the line marked {@code 216-} is the removed pre-change line, which used another test class's name; the line
 * marked {@code 256+} is its replacement.)
 */
@BeforeClass
215255
public static void startThreadPool() {
216-
THREAD_POOL = new TestThreadPool(TransportBroadcastByNodeActionTests.class.getSimpleName());
256+
THREAD_POOL = new TestThreadPool(TransportNodesActionTests.class.getSimpleName());
217257
}
218258

219259
@AfterClass
@@ -268,11 +308,9 @@ public void tearDown() throws Exception {
268308

269309
public TestTransportNodesAction getTestTransportNodesAction() {
270310
return new TestTransportNodesAction(
271-
THREAD_POOL,
272311
clusterService,
273312
transportService,
274313
new ActionFilters(Collections.emptySet()),
275-
TestNodesRequest::new,
276314
TestNodeRequest::new,
277315
THREAD_POOL.executor(ThreadPool.Names.GENERIC)
278316
);
@@ -302,11 +340,9 @@ private static class TestTransportNodesAction extends TransportNodesAction<
302340
TestNodeResponse> {
303341

304342
TestTransportNodesAction(
305-
ThreadPool threadPool,
306343
ClusterService clusterService,
307344
TransportService transportService,
308345
ActionFilters actionFilters,
309-
Writeable.Reader<TestNodesRequest> request,
310346
Writeable.Reader<TestNodeRequest> nodeRequest,
311347
Executor nodeExecutor
312348
) {
@@ -319,7 +355,7 @@ protected TestNodesResponse newResponse(
319355
List<TestNodeResponse> responses,
320356
List<FailedNodeException> failures
321357
) {
322-
return new TestNodesResponse(clusterService.getClusterName(), request, responses, failures);
358+
return new TestNodesResponse(clusterService.getClusterName(), responses, failures);
323359
}
324360

325361
@Override
@@ -350,7 +386,7 @@ private static class DataNodesOnlyTransportNodesAction extends TestTransportNode
350386
Writeable.Reader<TestNodeRequest> nodeRequest,
351387
Executor nodeExecutor
352388
) {
353-
super(threadPool, clusterService, transportService, actionFilters, request, nodeRequest, nodeExecutor);
389+
super(clusterService, transportService, actionFilters, nodeRequest, nodeExecutor);
354390
}
355391

356392
@Override
@@ -371,16 +407,8 @@ private static class TestNodesRequest extends BaseNodesRequest<TestNodesRequest>
371407

372408
private static class TestNodesResponse extends BaseNodesResponse<TestNodeResponse> {
373409

374-
private final TestNodesRequest request;
375-
376-
TestNodesResponse(
377-
ClusterName clusterName,
378-
TestNodesRequest request,
379-
List<TestNodeResponse> nodeResponses,
380-
List<FailedNodeException> failures
381-
) {
410+
TestNodesResponse(ClusterName clusterName, List<TestNodeResponse> nodeResponses, List<FailedNodeException> failures) {
382411
super(clusterName, nodeResponses, failures);
383-
this.request = request;
384412
}
385413

386414
@Override
@@ -425,6 +453,10 @@ public boolean hasReferences() {
425453
}
426454

427455
private static class TestNodeResponse extends BaseNodeResponse {
456+
457+
private final SubscribableListener<Void> onClose = new SubscribableListener<>();
458+
private final RefCounted refCounted = AbstractRefCounted.of(() -> onClose.onResponse(null));
459+
428460
TestNodeResponse() {
429461
this(mock(DiscoveryNode.class));
430462
}
@@ -436,6 +468,30 @@ private static class TestNodeResponse extends BaseNodeResponse {
436468
protected TestNodeResponse(StreamInput in) throws IOException {
437469
super(in);
438470
}
471+
472+
// The four ref-counting methods below forward directly to the refCounted field.
// NOTE(review): refCounted is created with a callback that completes onClose, presumably when the last reference is released --
// confirm against AbstractRefCounted#of.
@Override
473+
public void incRef() {
474+
refCounted.incRef();
475+
}
476+
477+
@Override
478+
public boolean tryIncRef() {
479+
return refCounted.tryIncRef();
480+
}
481+
482+
@Override
483+
public boolean decRef() {
484+
return refCounted.decRef();
485+
}
486+
487+
@Override
488+
public boolean hasReferences() {
489+
return refCounted.hasReferences();
490+
}
491+
/**
 * Subscribes {@code listener} to the {@code onClose} listener of this response.
 *
 * @param listener the listener to add to {@code onClose}
 */
492+
void addCloseListener(ActionListener<Void> listener) {
493+
onClose.addListener(listener);
494+
}
439495
}
440496

441497
}

0 commit comments

Comments
 (0)