Skip to content

Run newShardSnapshotTask tasks concurrently #126478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/126452.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Release-notes changelog entry for this change (docs/changelog/126452.yaml),
# consumed by the automated release-notes generator.
pr: 126452
summary: Run `newShardSnapshotTask` tasks concurrently
area: Snapshot/Restore
# NOTE(review): classified as `bug`; the change reads like a concurrency/throughput
# improvement — confirm the intended changelog type against the PR labels.
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,23 @@

package org.elasticsearch.snapshots;

import org.elasticsearch.action.admin.cluster.snapshots.create.CreateSnapshotRequest;
import org.elasticsearch.action.admin.cluster.snapshots.create.TransportCreateSnapshotAction;
import org.elasticsearch.cluster.SnapshotsInProgress;
import org.elasticsearch.index.snapshots.IndexShardSnapshotStatus;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.snapshots.mockstore.MockRepository;
import org.elasticsearch.test.ClusterServiceUtils;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.disruption.NetworkDisruption;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.threadpool.ThreadPoolStats;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
Expand Down Expand Up @@ -89,4 +96,58 @@ public void testRetryPostingSnapshotStatusMessages() throws Exception {
assertThat(snapshotInfo.successfulShards(), equalTo(shards));
}, 30L, TimeUnit.SECONDS);
}

/**
 * Verifies that the per-shard snapshot startup tasks are enqueued on the SNAPSHOT pool
 * as individual (throttled) tasks rather than as one monolithic runnable: with every
 * SNAPSHOT thread deliberately blocked, starting a snapshot must leave exactly
 * {@code min(snapshotThreadCount, shardCount) + 1} tasks in the SNAPSHOT queue.
 */
public void testStartSnapshotsConcurrently() {
internalCluster().startMasterOnlyNode();
final var dataNode = internalCluster().startDataOnlyNode();

final var repoName = randomIdentifier();
createRepository(repoName, "fs");

// size everything relative to the data node's SNAPSHOT pool so the assertion below is exact
final var threadPool = internalCluster().getInstance(ThreadPool.class, dataNode);
final var snapshotThreadCount = threadPool.info(ThreadPool.Names.SNAPSHOT).getMax();

final var indexName = randomIdentifier();
// shard count may be below or above the thread count, exercising both branches of the min() below
final var shardCount = between(1, snapshotThreadCount * 2);
assertAcked(prepareCreate(indexName, 0, indexSettingsNoReplicas(shardCount)));
indexRandomDocs(indexName, scaledRandomIntBetween(50, 100));

// occupy every SNAPSHOT thread: each blocker parks on the barrier twice
// (first rendezvous = "I'm blocked", second rendezvous = release)
final var snapshotExecutor = threadPool.executor(ThreadPool.Names.SNAPSHOT);
final var barrier = new CyclicBarrier(snapshotThreadCount + 1);
for (int i = 0; i < snapshotThreadCount; i++) {
snapshotExecutor.submit(() -> {
safeAwait(barrier);
safeAwait(barrier);
});
}

// wait until the snapshot threads are all blocked
safeAwait(barrier);

// start the snapshot; the request returns once the snapshot is initialized,
// so its shard-start tasks are now sitting in the (fully occupied) SNAPSHOT queue
safeGet(
client().execute(
TransportCreateSnapshotAction.TYPE,
new CreateSnapshotRequest(TEST_REQUEST_TIMEOUT, repoName, randomIdentifier())
)
);

// one task for each snapshot thread (throttled) or shard (if fewer), plus one for runSyncTasksEagerly()
assertEquals(Math.min(snapshotThreadCount, shardCount) + 1, getSnapshotQueueLength(threadPool));

// release all the snapshot threads
safeAwait(barrier);

// wait for completion
safeAwait(ClusterServiceUtils.addMasterTemporaryStateListener(cs -> SnapshotsInProgress.get(cs).isEmpty()));
}

/**
 * Returns the current queue length of the {@link ThreadPool.Names#SNAPSHOT} pool.
 *
 * @throws AssertionError if no stats entry for the SNAPSHOT pool exists
 */
private static int getSnapshotQueueLength(ThreadPool threadPool) {
    for (final var poolStats : threadPool.stats().stats()) {
        if (ThreadPool.Names.SNAPSHOT.equals(poolStats.name()) == false) {
            continue;
        }
        return poolStats.queue();
    }

    throw new AssertionError("threadpool stats for [" + ThreadPool.Names.SNAPSHOT + "] not found");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.engine.Engine;
import org.elasticsearch.index.seqno.SequenceNumbers;
Expand All @@ -54,7 +55,6 @@
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -87,6 +87,9 @@ public final class SnapshotShardsService extends AbstractLifecycleComponent impl
// A map of snapshots to the shardIds that we already reported to the master as failed
private final ResultDeduplicator<UpdateIndexShardSnapshotStatusRequest, Void> remoteFailedRequestDeduplicator;

// Runs the tasks that start each shard snapshot (e.g. acquiring the index commit)
private final ThrottledTaskRunner startShardSnapshotTaskRunner;

// Runs the tasks that promptly notify shards of aborted snapshots so that resources can be released ASAP
private final ThrottledTaskRunner notifyOnAbortTaskRunner;

Expand Down Expand Up @@ -114,6 +117,11 @@ public SnapshotShardsService(
threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(),
threadPool.generic()
);
this.startShardSnapshotTaskRunner = new ThrottledTaskRunner(
"start-shard-snapshots",
threadPool.info(ThreadPool.Names.SNAPSHOT).getMax(),
threadPool.executor(ThreadPool.Names.SNAPSHOT)
);
}

@Override
Expand Down Expand Up @@ -304,7 +312,6 @@ private void startNewShardSnapshots(String localNodeId, SnapshotsInProgress.Entr

final var newSnapshotShards = shardSnapshots.computeIfAbsent(snapshot, s -> new HashMap<>());

final List<Runnable> shardSnapshotTasks = new ArrayList<>(shardsToStart.size());
for (final Map.Entry<ShardId, ShardGeneration> shardEntry : shardsToStart.entrySet()) {
final ShardId shardId = shardEntry.getKey();
final IndexShardSnapshotStatus snapshotStatus = IndexShardSnapshotStatus.newInitializing(shardEntry.getValue());
Expand All @@ -316,10 +323,36 @@ private void startNewShardSnapshots(String localNodeId, SnapshotsInProgress.Entr
: "Found non-null, non-numeric shard generation ["
+ snapshotStatus.generation()
+ "] for snapshot with old-format compatibility";
shardSnapshotTasks.add(newShardSnapshotTask(shardId, snapshot, indexId, snapshotStatus, entry.version(), entry.startTime()));
final var shardSnapshotTask = newShardSnapshotTask(
shardId,
snapshot,
indexId,
snapshotStatus,
entry.version(),
entry.startTime()
);
startShardSnapshotTaskRunner.enqueueTask(new ActionListener<>() {
@Override
public void onResponse(Releasable releasable) {
try (releasable) {
shardSnapshotTask.run();
}
}

@Override
public void onFailure(Exception e) {
final var wrapperException = new IllegalStateException(
"impossible failure starting shard snapshot for " + shardId + " in " + snapshot,
e
);
logger.error(wrapperException.getMessage(), wrapperException);
assert false : wrapperException; // impossible
}
});
}

threadPool.executor(ThreadPool.Names.SNAPSHOT).execute(() -> shardSnapshotTasks.forEach(Runnable::run));
// apply some backpressure by reserving one SNAPSHOT thread for the startup work
startShardSnapshotTaskRunner.runSyncTasksEagerly(threadPool.executor(ThreadPool.Names.SNAPSHOT));
}

private void pauseShardSnapshotsForNodeRemoval(String localNodeId, SnapshotsInProgress.Entry entry) {
Expand Down