Skip to content

Commit a5a9de4

Browse files
ankitkalagithub-actions[bot]
authored andcommitted
Fix for missing ShardReplicationTasks on new nodes (#497)
Signed-off-by: Ankit Kala <ankikala@amazon.com> Signed-off-by: Ankit Kala <ankikala@amazon.com> (cherry picked from commit 805f686)
1 parent 81ca400 commit a5a9de4

File tree

2 files changed

+85
-29
lines changed

2 files changed

+85
-29
lines changed

src/main/kotlin/org/opensearch/replication/task/index/IndexReplicationTask.kt

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
183183
ReplicationState.INIT_FOLLOW -> {
184184
log.info("Starting shard tasks")
185185
addIndexBlockForReplication()
186-
startShardFollowTasks(emptyMap())
186+
FollowingState(startNewOrMissingShardTasks())
187+
187188
}
188189
ReplicationState.FOLLOWING -> {
189190
if (currentTaskState is FollowingState) {
@@ -206,8 +207,8 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
206207
// Tasks need to be started
207208
state
208209
} else {
209-
state = pollShardTaskStatus((followingTaskState as FollowingState).shardReplicationTasks)
210-
followingTaskState = startMissingShardTasks((followingTaskState as FollowingState).shardReplicationTasks)
210+
state = pollShardTaskStatus()
211+
followingTaskState = FollowingState(startNewOrMissingShardTasks())
211212
when (state) {
212213
is MonitoringState -> {
213214
updateMetadata()
@@ -285,24 +286,7 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
285286
clusterService.addListener(this)
286287
}
287288

288-
private suspend fun startMissingShardTasks(shardTasks: Map<ShardId, PersistentTask<ShardReplicationParams>>): IndexReplicationState {
289-
val persistentTasks = clusterService.state().metadata.custom<PersistentTasksCustomMetadata>(PersistentTasksCustomMetadata.TYPE)
290-
291-
val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream()
292-
.map { task -> task.params as ShardReplicationParams }
293-
.collect(Collectors.toList())
294-
295-
val runningTasksForCurrentIndex = shardTasks.filter { entry -> runningShardTasks.find { task -> task.followerShardId == entry.key } != null}
296-
297-
val numMissingTasks = shardTasks.size - runningTasksForCurrentIndex.size
298-
if (numMissingTasks > 0) {
299-
log.info("Starting $numMissingTasks missing shard task(s)")
300-
return startShardFollowTasks(runningTasksForCurrentIndex)
301-
}
302-
return FollowingState(shardTasks)
303-
}
304-
305-
private suspend fun pollShardTaskStatus(shardTasks: Map<ShardId, PersistentTask<ShardReplicationParams>>): IndexReplicationState {
289+
private suspend fun pollShardTaskStatus(): IndexReplicationState {
306290
val failedShardTasks = findAllReplicationFailedShardTasks(followerIndexName, clusterService.state())
307291
if (failedShardTasks.isNotEmpty()) {
308292
log.info("Failed shard tasks - ", failedShardTasks)
@@ -343,11 +327,16 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
343327
registerCloseListeners()
344328
val clusterState = clusterService.state()
345329
val persistentTasks = clusterState.metadata.custom<PersistentTasksCustomMetadata>(PersistentTasksCustomMetadata.TYPE)
346-
val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream()
330+
331+
val followerShardIds = clusterService.state().routingTable.indicesRouting().get(followerIndexName).shards()
332+
.map { shard -> shard.value.shardId }
333+
.stream().collect(Collectors.toSet())
334+
val runningShardTasksForIndex = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream()
347335
.map { task -> task.params as ShardReplicationParams }
336+
.filter {taskParam -> followerShardIds.contains(taskParam.followerShardId) }
348337
.collect(Collectors.toList())
349338

350-
if (runningShardTasks.size == 0) {
339+
if (runningShardTasksForIndex.size != followerShardIds.size) {
351340
return InitFollowState
352341
}
353342

@@ -696,19 +685,27 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
696685
}
697686
}
698687

699-
private suspend fun
700-
startShardFollowTasks(tasks: Map<ShardId, PersistentTask<ShardReplicationParams>>): FollowingState {
688+
suspend fun startNewOrMissingShardTasks(): Map<ShardId, PersistentTask<ShardReplicationParams>> {
701689
assert(clusterService.state().routingTable.hasIndex(followerIndexName)) { "Can't find index $followerIndexName" }
702690
val shards = clusterService.state().routingTable.indicesRouting().get(followerIndexName).shards()
703-
val newTasks = shards.map {
691+
val persistentTasks = clusterService.state().metadata.custom<PersistentTasksCustomMetadata>(PersistentTasksCustomMetadata.TYPE)
692+
val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream()
693+
.map { task -> task as PersistentTask<ShardReplicationParams> }
694+
.filter { task -> task.params!!.followerShardId.indexName == followerIndexName}
695+
.collect(Collectors.toMap(
696+
{t: PersistentTask<ShardReplicationParams> -> t.params!!.followerShardId},
697+
{t: PersistentTask<ShardReplicationParams> -> t}))
698+
699+
val tasks = shards.map {
704700
it.value.shardId
705701
}.associate { shardId ->
706-
val task = tasks.getOrElse(shardId) {
702+
val task = runningShardTasks.getOrElse(shardId) {
707703
startReplicationTask(ShardReplicationParams(leaderAlias, ShardId(leaderIndex, shardId.id), shardId))
708704
}
709705
return@associate shardId to task
710706
}
711-
return FollowingState(newTasks)
707+
708+
return tasks
712709
}
713710

714711
private suspend fun cancelRestore() {

src/test/kotlin/org/opensearch/replication/task/index/IndexReplicationTaskTests.kt

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ import org.opensearch.tasks.TaskManager
5555
import org.opensearch.test.ClusterServiceUtils
5656
import org.opensearch.test.ClusterServiceUtils.setState
5757
import org.opensearch.test.OpenSearchTestCase
58-
import org.opensearch.test.OpenSearchTestCase.assertBusy
5958
import org.opensearch.threadpool.TestThreadPool
6059
import java.util.*
6160
import java.util.concurrent.TimeUnit
@@ -150,6 +149,66 @@ class IndexReplicationTaskTests : OpenSearchTestCase() {
150149

151150
}
152151

152+
fun testStartNewShardTasks() = runBlocking {
153+
val replicationTask: IndexReplicationTask = spy(createIndexReplicationTask())
154+
var taskManager = Mockito.mock(TaskManager::class.java)
155+
replicationTask.setPersistent(taskManager)
156+
var rc = ReplicationContext(followerIndex)
157+
var rm = ReplicationMetadata(connectionName, ReplicationStoreMetadataType.INDEX.name, ReplicationOverallState.RUNNING.name, "reason", rc, rc, Settings.EMPTY)
158+
replicationTask.setReplicationMetadata(rm)
159+
160+
// Build cluster state
161+
val indices: MutableList<String> = ArrayList()
162+
indices.add(followerIndex)
163+
var metadata = Metadata.builder()
164+
.put(IndexMetadata.builder(REPLICATION_CONFIG_SYSTEM_INDEX).settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0))
165+
.put(IndexMetadata.builder(followerIndex).settings(settings(Version.CURRENT)).numberOfShards(2).numberOfReplicas(0))
166+
.build()
167+
var routingTableBuilder = RoutingTable.builder()
168+
.addAsNew(metadata.index(REPLICATION_CONFIG_SYSTEM_INDEX))
169+
.addAsNew(metadata.index(followerIndex))
170+
var newClusterState = ClusterState.builder(clusterService.state()).routingTable(routingTableBuilder.build()).build()
171+
setState(clusterService, newClusterState)
172+
173+
// Try starting shard tasks
174+
val shardTasks = replicationTask.startNewOrMissingShardTasks()
175+
assertThat(shardTasks.size == 2).isTrue
176+
}
177+
178+
179+
fun testStartMissingShardTasks() = runBlocking {
180+
val replicationTask: IndexReplicationTask = spy(createIndexReplicationTask())
181+
var taskManager = Mockito.mock(TaskManager::class.java)
182+
replicationTask.setPersistent(taskManager)
183+
var rc = ReplicationContext(followerIndex)
184+
var rm = ReplicationMetadata(connectionName, ReplicationStoreMetadataType.INDEX.name, ReplicationOverallState.RUNNING.name, "reason", rc, rc, Settings.EMPTY)
185+
replicationTask.setReplicationMetadata(rm)
186+
187+
// Build cluster state
188+
val indices: MutableList<String> = ArrayList()
189+
indices.add(followerIndex)
190+
191+
val tasks = PersistentTasksCustomMetadata.builder()
192+
var sId = ShardId(Index(followerIndex, "_na_"), 0)
193+
tasks.addTask<PersistentTaskParams>( "replication:0", ShardReplicationExecutor.TASK_NAME, ShardReplicationParams("remoteCluster", sId, sId),
194+
PersistentTasksCustomMetadata.Assignment("other_node_", "test assignment on other node"))
195+
196+
var metadata = Metadata.builder()
197+
.put(IndexMetadata.builder(REPLICATION_CONFIG_SYSTEM_INDEX).settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0))
198+
.put(IndexMetadata.builder(followerIndex).settings(settings(Version.CURRENT)).numberOfShards(2).numberOfReplicas(0))
199+
.putCustom(PersistentTasksCustomMetadata.TYPE, tasks.build())
200+
.build()
201+
var routingTableBuilder = RoutingTable.builder()
202+
.addAsNew(metadata.index(REPLICATION_CONFIG_SYSTEM_INDEX))
203+
.addAsNew(metadata.index(followerIndex))
204+
var newClusterState = ClusterState.builder(clusterService.state()).routingTable(routingTableBuilder.build()).build()
205+
setState(clusterService, newClusterState)
206+
207+
// Try starting shard tasks
208+
val shardTasks = replicationTask.startNewOrMissingShardTasks()
209+
assertThat(shardTasks.size == 2).isTrue
210+
}
211+
153212
private fun createIndexReplicationTask() : IndexReplicationTask {
154213
var threadPool = TestThreadPool("IndexReplicationTask")
155214
//Hack Alert : Though it is meant to force rejection , this is to make overallTaskScope not null

0 commit comments

Comments
 (0)