Skip to content

Refcount responses in TransportNodesAction #103254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
19 changes: 19 additions & 0 deletions libs/core/src/main/java/org/elasticsearch/core/Releasables.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicReference;

/** Utility methods to work with {@link Releasable}s. */
Expand Down Expand Up @@ -103,6 +104,24 @@ public String toString() {
};
}

/**
* Similar to {@link #wrap(Iterable)} except that it accepts an {@link Iterator} of releasables. The resulting resource must therefore
* only be released once.
*/
public static Releasable wrap(final Iterator<Releasable> releasables) {
return assertOnce(wrap(new Iterable<>() {
@Override
public Iterator<Releasable> iterator() {
return releasables;
}

@Override
public String toString() {
return releasables.toString();
}
}));
}

/** @see #wrap(Iterable) */
public static Releasable wrap(final Releasable... releasables) {
return new Releasable() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,5 +107,27 @@ public String toString() {
assertEquals("wrapped[list]", wrapIterable.toString());
wrapIterable.close();
assertEquals(5, count.get());

final var wrapIterator = Releasables.wrap(new Iterator<>() {
final Iterator<Releasable> innerIterator = List.of(releasable, releasable, releasable).iterator();

@Override
public boolean hasNext() {
return innerIterator.hasNext();
}

@Override
public Releasable next() {
return innerIterator.next();
}

@Override
public String toString() {
return "iterator";
}
});
assertEquals("wrapped[iterator]", wrapIterator.toString());
wrapIterator.close();
assertEquals(8, count.get());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportRequest;
Expand Down Expand Up @@ -96,6 +98,23 @@ protected void doExecute(Task task, NodesRequest request, ActionListener<NodesRe

final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout());

{
addReleaseOnCancellationListener();
}

private void addReleaseOnCancellationListener() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incredibly if you write this inline in the anonymous constructor then the compiler blows up with a NPE: https://gradle-enterprise.elastic.co/s/pp3aym5qktpa2

if (task instanceof CancellableTask cancellableTask) {
cancellableTask.addListener(() -> {
final List<NodeResponse> drainedResponses;
synchronized (responses) {
drainedResponses = List.copyOf(responses);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering if could add copy and remove elements in one iteration:

var iterator = responses.iterator();
while (iterator.hasNext()) {
    drainedResponses.add(iterator.next());
    iterator.remove();
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can, but adding elements one-by-one to a list potentially involves more allocations to grow the list incrementally, and removing items from a list iterator one-by-one is an O(N²) operation.

responses.clear();
}
Releasables.wrap(Iterators.map(drainedResponses.iterator(), r -> r::decRef)).close();
});
}
}

@Override
protected void sendItemRequest(DiscoveryNode discoveryNode, ActionListener<NodeResponse> listener) {
final var nodeRequest = newNodeRequest(request);
Expand All @@ -118,9 +137,14 @@ protected void sendItemRequest(DiscoveryNode discoveryNode, ActionListener<NodeR

@Override
protected void onItemResponse(DiscoveryNode discoveryNode, NodeResponse nodeResponse) {
nodeResponse.mustIncRef();
synchronized (responses) {
responses.add(nodeResponse);
if ((task instanceof CancellableTask cancellableTask && cancellableTask.isCancelled()) == false) {
responses.add(nodeResponse);
return;
}
}
nodeResponse.decRef();
}

@Override
Expand All @@ -134,7 +158,11 @@ protected void onItemFailure(DiscoveryNode discoveryNode, Exception e) {
@Override
protected CheckedConsumer<ActionListener<NodesResponse>, Exception> onCompletion() {
// ref releases all happen-before here so no need to be synchronized
return l -> newResponseAsync(task, request, responses, exceptions, l);
return l -> {
try (var ignored = Releasables.wrap(Iterators.map(responses.iterator(), r -> r::decRef))) {
newResponseAsync(task, request, responses, exceptions, l);
}
};
}

@Override
Expand All @@ -154,9 +182,11 @@ private Writeable.Reader<NodeResponse> nodeResponseReader(DiscoveryNode discover
}

/**
* Create a new {@link NodesResponse} (multi-node response).
* Create a new {@link NodesResponse}. This method is executed on {@link #finalExecutor}.
*
* @param request The associated request.
* @param request The request whose response we are constructing. {@link TransportNodesAction} may have already released all its
* references to this object before calling this method, so it's up to individual implementations to retain their own
* reference to the request if still needed here.
* @param responses All successful node-level responses.
* @param failures All node-level failures.
* @return Never {@code null}.
Expand All @@ -166,7 +196,11 @@ private Writeable.Reader<NodeResponse> nodeResponseReader(DiscoveryNode discover

/**
* Create a new {@link NodesResponse}, possibly asynchronously. The default implementation is synchronous and calls
* {@link #newResponse(BaseNodesRequest, List, List)}
* {@link #newResponse(BaseNodesRequest, List, List)}. This method is executed on {@link #finalExecutor}.
*
* @param request The request whose response we are constructing. {@link TransportNodesAction} may have already released all its
* references to this object before calling this method, so it's up to individual implementations to retain their own
* reference to the request if still needed here.
*/
protected void newResponseAsync(
Task task,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
package org.elasticsearch.action.support.nodes;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.broadcast.node.TransportBroadcastByNodeActionTests;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
Expand Down Expand Up @@ -55,6 +57,9 @@
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.ObjLongConsumer;

import static java.util.Collections.emptyMap;
import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;
Expand Down Expand Up @@ -118,7 +123,10 @@ public void testResponseAggregation() {
final TestTransportNodesAction action = getTestTransportNodesAction();

final PlainActionFuture<TestNodesResponse> listener = new PlainActionFuture<>();
action.execute(null, new TestNodesRequest(), listener);
action.execute(null, new TestNodesRequest(), listener.delegateFailure((l, response) -> {
assertTrue(response.getNodes().stream().allMatch(TestNodeResponse::hasReferences));
l.onResponse(response);
}));
assertFalse(listener.isDone());

final Set<String> failedNodeIds = new HashSet<>();
Expand All @@ -127,7 +135,9 @@ public void testResponseAggregation() {
for (CapturingTransport.CapturedRequest capturedRequest : transport.getCapturedRequestsAndClear()) {
if (randomBoolean()) {
successfulNodes.add(capturedRequest.node());
transport.handleResponse(capturedRequest.requestId(), new TestNodeResponse(capturedRequest.node()));
final var response = new TestNodeResponse(capturedRequest.node());
transport.handleResponse(capturedRequest.requestId(), response);
assertFalse(response.hasReferences()); // response is copied (via the wire protocol) so this instance is released
} else {
failedNodeIds.add(capturedRequest.node().getId());
if (randomBoolean()) {
Expand All @@ -138,7 +148,16 @@ public void testResponseAggregation() {
}
}

TestNodesResponse response = listener.actionGet(10, TimeUnit.SECONDS);
final TestNodesResponse response = listener.actionGet(10, TimeUnit.SECONDS);

final var allResponsesReleasedListener = new SubscribableListener<Void>();
try (var listeners = new RefCountingListener(allResponsesReleasedListener)) {
for (final var nodeResponse : response.getNodes()) {
nodeResponse.addCloseListener(listeners.acquire());
}
}
safeAwait(allResponsesReleasedListener);
assertTrue(response.getNodes().stream().noneMatch(TestNodeResponse::hasReferences));

for (TestNodeResponse nodeResponse : response.getNodes()) {
assertThat(successfulNodes, Matchers.hasItem(nodeResponse.getNode()));
Expand All @@ -164,7 +183,7 @@ public void testResponsesReleasedOnCancellation() {
final CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());
final PlainActionFuture<TestNodesResponse> listener = new PlainActionFuture<>();
action.execute(cancellableTask, new TestNodesRequest(), listener.delegateResponse((l, e) -> {
assert Thread.currentThread().getName().contains("[" + ThreadPool.Names.GENERIC + "]");
assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.GENERIC);
l.onFailure(e);
}));

Expand All @@ -173,13 +192,31 @@ public void testResponsesReleasedOnCancellation() {
);
Randomness.shuffle(capturedRequests);

final AtomicInteger liveResponseCount = new AtomicInteger();
final Function<DiscoveryNode, TestNodeResponse> responseCreator = node -> {
liveResponseCount.incrementAndGet();
final var testNodeResponse = new TestNodeResponse(node);
testNodeResponse.addCloseListener(ActionListener.running(liveResponseCount::decrementAndGet));
return testNodeResponse;
};

final ObjLongConsumer<TestNodeResponse> responseSender = (response, requestId) -> {
try {
// transport.handleResponse may de/serialize the response, releasing it early, so send the response straight to the handler
transport.getTransportResponseHandler(requestId).handleResponse(response);
} finally {
response.decRef();
}
};

final ReachabilityChecker reachabilityChecker = new ReachabilityChecker();
final Runnable nextRequestProcessor = () -> {
var capturedRequest = capturedRequests.remove(0);
if (randomBoolean()) {
// transport.handleResponse may de/serialize the response, releasing it early, so send the response straight to the handler
transport.getTransportResponseHandler(capturedRequest.requestId())
.handleResponse(reachabilityChecker.register(new TestNodeResponse(capturedRequest.node())));
responseSender.accept(
reachabilityChecker.register(responseCreator.apply(capturedRequest.node())),
capturedRequest.requestId()
);
} else {
// handleRemoteError may de/serialize the exception, releasing it early, so just use handleLocalError
transport.handleLocalError(
Expand All @@ -200,20 +237,23 @@ public void testResponsesReleasedOnCancellation() {

// responses captured before cancellation are now unreachable
reachabilityChecker.ensureUnreachable();
assertEquals(0, liveResponseCount.get());

while (capturedRequests.size() > 0) {
// a response sent after cancellation is dropped immediately
assertFalse(listener.isDone());
nextRequestProcessor.run();
reachabilityChecker.ensureUnreachable();
assertEquals(0, liveResponseCount.get());
}

expectThrows(TaskCancelledException.class, () -> listener.actionGet(10, TimeUnit.SECONDS));
assertTrue(cancellableTask.isCancelled()); // keep task alive
}

@BeforeClass
public static void startThreadPool() {
THREAD_POOL = new TestThreadPool(TransportBroadcastByNodeActionTests.class.getSimpleName());
THREAD_POOL = new TestThreadPool(TransportNodesActionTests.class.getSimpleName());
}

@AfterClass
Expand Down Expand Up @@ -268,11 +308,9 @@ public void tearDown() throws Exception {

public TestTransportNodesAction getTestTransportNodesAction() {
return new TestTransportNodesAction(
THREAD_POOL,
clusterService,
transportService,
new ActionFilters(Collections.emptySet()),
TestNodesRequest::new,
TestNodeRequest::new,
THREAD_POOL.executor(ThreadPool.Names.GENERIC)
);
Expand Down Expand Up @@ -302,11 +340,9 @@ private static class TestTransportNodesAction extends TransportNodesAction<
TestNodeResponse> {

TestTransportNodesAction(
ThreadPool threadPool,
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
Writeable.Reader<TestNodesRequest> request,
Writeable.Reader<TestNodeRequest> nodeRequest,
Executor nodeExecutor
) {
Expand All @@ -319,7 +355,7 @@ protected TestNodesResponse newResponse(
List<TestNodeResponse> responses,
List<FailedNodeException> failures
) {
return new TestNodesResponse(clusterService.getClusterName(), request, responses, failures);
return new TestNodesResponse(clusterService.getClusterName(), responses, failures);
}

@Override
Expand Down Expand Up @@ -350,7 +386,7 @@ private static class DataNodesOnlyTransportNodesAction extends TestTransportNode
Writeable.Reader<TestNodeRequest> nodeRequest,
Executor nodeExecutor
) {
super(threadPool, clusterService, transportService, actionFilters, request, nodeRequest, nodeExecutor);
super(clusterService, transportService, actionFilters, nodeRequest, nodeExecutor);
}

@Override
Expand All @@ -371,16 +407,8 @@ private static class TestNodesRequest extends BaseNodesRequest<TestNodesRequest>

private static class TestNodesResponse extends BaseNodesResponse<TestNodeResponse> {

private final TestNodesRequest request;

TestNodesResponse(
ClusterName clusterName,
TestNodesRequest request,
List<TestNodeResponse> nodeResponses,
List<FailedNodeException> failures
) {
TestNodesResponse(ClusterName clusterName, List<TestNodeResponse> nodeResponses, List<FailedNodeException> failures) {
super(clusterName, nodeResponses, failures);
this.request = request;
}

@Override
Expand Down Expand Up @@ -425,6 +453,10 @@ public boolean hasReferences() {
}

private static class TestNodeResponse extends BaseNodeResponse {

private final SubscribableListener<Void> onClose = new SubscribableListener<>();
private final RefCounted refCounted = AbstractRefCounted.of(() -> onClose.onResponse(null));

TestNodeResponse() {
this(mock(DiscoveryNode.class));
}
Expand All @@ -436,6 +468,30 @@ private static class TestNodeResponse extends BaseNodeResponse {
protected TestNodeResponse(StreamInput in) throws IOException {
super(in);
}

@Override
public void incRef() {
refCounted.incRef();
}

@Override
public boolean tryIncRef() {
return refCounted.tryIncRef();
}

@Override
public boolean decRef() {
return refCounted.decRef();
}

@Override
public boolean hasReferences() {
return refCounted.hasReferences();
}

void addCloseListener(ActionListener<Void> listener) {
onClose.addListener(listener);
}
}

}