Skip to content

Preserve context in ResultDeduplicator #84038

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/84038.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 84038
summary: Preserve context in `ResultDeduplicator`
area: Infra/Core
type: bug
issues:
- 84036
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

package org.elasticsearch.action;

import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.ThreadContext;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -22,8 +24,13 @@
*/
public final class ResultDeduplicator<T, R> {

private final ThreadContext threadContext;
private final ConcurrentMap<T, CompositeListener> requests = ConcurrentCollections.newConcurrentMap();

public ResultDeduplicator(ThreadContext threadContext) {
this.threadContext = threadContext;
}

/**
* Ensures a given request not executed multiple times when another equal request is already in-flight.
* If the request is not yet known to the deduplicator it will invoke the passed callback with an {@link ActionListener}
Expand All @@ -35,7 +42,8 @@ public final class ResultDeduplicator<T, R> {
* @param callback Callback to be invoked with request and completion listener the first time the request is added to the deduplicator
*/
public void executeOnce(T request, ActionListener<R> listener, BiConsumer<T, ActionListener<R>> callback) {
ActionListener<R> completionListener = requests.computeIfAbsent(request, CompositeListener::new).addListener(listener);
ActionListener<R> completionListener = requests.computeIfAbsent(request, CompositeListener::new)
.addListener(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext));
if (completionListener != null) {
callback.accept(request, completionListener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public class ShardStateAction {
private final ThreadPool threadPool;

// we deduplicate these shard state requests in order to avoid sending duplicate failed/started shard requests for a shard
private final ResultDeduplicator<TransportRequest, Void> remoteShardStateUpdateDeduplicator = new ResultDeduplicator<>();
private final ResultDeduplicator<TransportRequest, Void> remoteShardStateUpdateDeduplicator;

@Inject
public ShardStateAction(
Expand All @@ -94,6 +94,7 @@ public ShardStateAction(
this.transportService = transportService;
this.clusterService = clusterService;
this.threadPool = threadPool;
this.remoteShardStateUpdateDeduplicator = new ResultDeduplicator<>(threadPool.getThreadContext());

transportService.registerRequestHandler(
SHARD_STARTED_ACTION_NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ protected BlobStoreRepository(
this.namedXContentRegistry = namedXContentRegistry;
this.basePath = basePath;
this.maxSnapshotCount = MAX_SNAPSHOTS_SETTING.get(metadata.settings());
this.repoDataDeduplicator = new ResultDeduplicator<>(threadPool.getThreadContext());
}

@Override
Expand Down Expand Up @@ -1866,7 +1867,7 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState)
* {@link #bestEffortConsistency} must be {@code false}, in which case we can assume that the {@link RepositoryData} loaded is
* unique for a given value of {@link #metadata} at any point in time.
*/
private final ResultDeduplicator<RepositoryMetadata, RepositoryData> repoDataDeduplicator = new ResultDeduplicator<>();
private final ResultDeduplicator<RepositoryMetadata, RepositoryData> repoDataDeduplicator;

private void doGetRepositoryData(ActionListener<RepositoryData> listener) {
// Retry loading RepositoryData in a loop in case we run into concurrent modifications of the repository.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements
private final Map<Snapshot, Map<ShardId, IndexShardSnapshotStatus>> shardSnapshots = new HashMap<>();

// A map of snapshots to the shardIds that we already reported to the master as failed
private final ResultDeduplicator<UpdateIndexShardSnapshotStatusRequest, Void> remoteFailedRequestDeduplicator =
new ResultDeduplicator<>();
private final ResultDeduplicator<UpdateIndexShardSnapshotStatusRequest, Void> remoteFailedRequestDeduplicator;

public SnapshotShardsService(
Settings settings,
Expand All @@ -97,6 +96,7 @@ public SnapshotShardsService(
this.transportService = transportService;
this.clusterService = clusterService;
this.threadPool = transportService.getThreadPool();
this.remoteFailedRequestDeduplicator = new ResultDeduplicator<>(threadPool.getThreadContext());
if (DiscoveryNode.canContainData(settings)) {
// this is only useful on the nodes that can hold data
clusterService.addListener(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ public class TaskCancellationService {
private static final Logger logger = LogManager.getLogger(TaskCancellationService.class);
private final TransportService transportService;
private final TaskManager taskManager;
private final ResultDeduplicator<CancelRequest, Void> deduplicator = new ResultDeduplicator<>();
private final ResultDeduplicator<CancelRequest, Void> deduplicator;

public TaskCancellationService(TransportService transportService) {
this.transportService = transportService;
this.taskManager = transportService.getTaskManager();
this.deduplicator = new ResultDeduplicator<>(transportService.getThreadPool().getThreadContext());
transportService.registerRequestHandler(
BAN_PARENT_ACTION_NAME,
ThreadPool.Names.SAME,
Expand Down
13 changes: 10 additions & 3 deletions server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import static org.hamcrest.Matchers.everyItem;
import static org.hamcrest.Matchers.in;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class TaskManagerTests extends ESTestCase {
private ThreadPool threadPool;
Expand Down Expand Up @@ -76,7 +77,9 @@ public void testResultsServiceRetryTotalTime() {
public void testTrackingChannelTask() throws Exception {
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of());
Set<Task> cancelledTasks = ConcurrentCollections.newConcurrentSet();
taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) {
final var transportServiceMock = mock(TransportService.class);
when(transportServiceMock.getThreadPool()).thenReturn(threadPool);
taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) {
@Override
void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
assertThat(reason, equalTo("channel was closed"));
Expand Down Expand Up @@ -124,7 +127,9 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF
public void testTrackingTaskAndCloseChannelConcurrently() throws Exception {
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of());
Set<CancellableTask> cancelledTasks = ConcurrentCollections.newConcurrentSet();
taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) {
final var transportServiceMock = mock(TransportService.class);
when(transportServiceMock.getThreadPool()).thenReturn(threadPool);
taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) {
@Override
void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
assertTrue("task [" + task + "] was cancelled already", cancelledTasks.add(task));
Expand Down Expand Up @@ -180,7 +185,9 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF

public void testRemoveBansOnChannelDisconnects() throws Exception {
final TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Set.of());
taskManager.setTaskCancellationService(new TaskCancellationService(mock(TransportService.class)) {
final var transportServiceMock = mock(TransportService.class);
when(transportServiceMock.getThreadPool()).thenReturn(threadPool);
taskManager.setTaskCancellationService(new TaskCancellationService(transportServiceMock) {
@Override
void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ResultDeduplicator;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;

Expand All @@ -29,27 +31,36 @@ public void testRequestDeduplication() throws Exception {
@Override
public void setParentTask(final TaskId taskId) {}
};
final ResultDeduplicator<TransportRequest, Void> deduplicator = new ResultDeduplicator<>();
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
final ResultDeduplicator<TransportRequest, Void> deduplicator = new ResultDeduplicator<>(threadContext);
final SetOnce<ActionListener<Void>> listenerHolder = new SetOnce<>();
final var headerName = "thread-context-header";
final var headerGenerator = new AtomicInteger();
int iterationsPerThread = scaledRandomIntBetween(100, 1000);
Thread[] threads = new Thread[between(1, 4)];
Phaser barrier = new Phaser(threads.length + 1);
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(() -> {
barrier.arriveAndAwaitAdvance();
for (int n = 0; n < iterationsPerThread; n++) {
deduplicator.executeOnce(request, new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
successCount.incrementAndGet();
}
final var headerValue = Integer.toString(headerGenerator.incrementAndGet());
try (var ignored = threadContext.stashContext()) {
threadContext.putHeader(headerName, headerValue);
deduplicator.executeOnce(request, new ActionListener<>() {
@Override
public void onResponse(Void aVoid) {
assertThat(threadContext.getHeader(headerName), equalTo(headerValue));
successCount.incrementAndGet();
}

@Override
public void onFailure(Exception e) {
assertThat(e, sameInstance(failure));
failureCount.incrementAndGet();
}
}, (req, reqListener) -> listenerHolder.set(reqListener));
@Override
public void onFailure(Exception e) {
assertThat(threadContext.getHeader(headerName), equalTo(headerValue));
assertThat(e, sameInstance(failure));
failureCount.incrementAndGet();
}
}, (req, reqListener) -> listenerHolder.set(reqListener));
}
}
});
threads[i].start();
Expand Down